{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
    "\n",
    "# Merlin ETL, training, and inference with e-Commerce behavior data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "In this tutorial, we use the [eCommerce behavior data from multi category store](https://www.kaggle.com/mkechinov/ecommerce-behavior-data-from-multi-category-store) from [REES46 Marketing Platform](https://rees46.com/) as our dataset. This tutorial is built upon the NVIDIA RecSys 2020 [tutorial](https://recsys.acm.org/recsys20/tutorials/). \n",
    "\n",
    "This notebook provides the code to preprocess the dataset and generate the training, validation, and test sets for the remainder of the tutorial. We define our own goal and filter the dataset accordingly.\n",
    "\n",
    "For this tutorial, the goal is to predict whether a user purchases an item:\n",
    "\n",
    "-  Positive: User added an item to the cart and purchased it (in the same session).\n",
    "-  Negative: User added an item to the cart, but did not purchase it (in the same session).\n",
    "\n",
    "We split the dataset into training, validation, and test set by the timestamp:\n",
    "\n",
    "- Training: December 2019 - February 2020\n",
    "- Validation: March 2020\n",
    "- Test: April 2020\n",
    "\n",
    "We remove add-to-cart events from a session if the same item was also purchased in that session, so such an item appears only once, as a positive."
   ]
  },
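  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The labeling rule above can be sketched on a tiny hand-made DataFrame. This is a minimal illustration under stated assumptions, not part of the pipeline: the toy events and the helper name `label_cart_purchase` are hypothetical, and the filtering is meant to mirror what the `process_files` function later in this notebook does. With the toy data, the purchase of item 1 becomes a positive, the abandoned cart event for item 2 becomes a negative, and both the purchase of item 3 (which has no cart event) and the view event are dropped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy events: in session s1 the user carts and buys item 1,\n",
    "# carts item 2 without buying it, and views item 4; in session s2 the user\n",
    "# buys item 3 without a cart event.\n",
    "toy_events = pd.DataFrame({\n",
    "    'user_session': ['s1', 's1', 's1', 's1', 's2'],\n",
    "    'product_id': [1, 1, 2, 4, 3],\n",
    "    'event_type': ['cart', 'purchase', 'cart', 'view', 'purchase'],\n",
    "})\n",
    "\n",
    "def label_cart_purchase(events):\n",
    "    # Key that identifies an item within a session\n",
    "    key = events['user_session'] + '_' + events['product_id'].astype(str)\n",
    "    is_cart = events['event_type'] == 'cart'\n",
    "    is_purchase = events['event_type'] == 'purchase'\n",
    "    # Positives: purchases of items that were added to the cart in the same session\n",
    "    pos = events[is_purchase & key.isin(key[is_cart])].copy()\n",
    "    # Negatives: cart events for items that were not purchased in the same session\n",
    "    neg = events[is_cart & ~key.isin(key[is_purchase])].copy()\n",
    "    pos['target'] = 1\n",
    "    neg['target'] = 0\n",
    "    return pd.concat([neg, pos])\n",
    "\n",
    "labeled = label_cart_purchase(toy_events)\n",
    "labeled"
   ]
  },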
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
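    "First, we download the raw data. The monthly files are gzip-compressed CSVs that pandas can read directly, so no separate unzip step is needed.\n",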
    "\n",
    "Note: the dataset is approximately 11 GB and can require several minutes to download."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "export HOME=$PWD\n",
    "pip install gdown --user\n",
    "~/.local/bin/gdown  https://drive.google.com/uc?id=1-Rov9fFtGJqb7_ePc6qH-Rhzxn0cIcKB\n",
    "~/.local/bin/gdown  https://drive.google.com/uc?id=1zr_RXpGvOWN2PrWI6itWL8HnRsCpyqz8\n",
    "~/.local/bin/gdown  https://drive.google.com/uc?id=1g5WoIgLe05UMdREbxAjh0bEFgVCjA1UL\n",
    "~/.local/bin/gdown  https://drive.google.com/uc?id=1qZIwMbMgMmgDC5EoMdJ8aI9lQPsWA3-P\n",
    "~/.local/bin/gdown  https://drive.google.com/uc?id=1x5ohrrZNhWQN4Q-zww0RmXOwctKHH9PT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['2019-Dec.csv.gz',\n",
       " '2020-Apr.csv.gz',\n",
       " '2020-Mar.csv.gz',\n",
       " '2020-Feb.csv.gz',\n",
       " '2020-Jan.csv.gz']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import glob  \n",
    "\n",
    "list_files = glob.glob('*.csv.gz')\n",
    "list_files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data extraction and initial preprocessing\n",
    "\n",
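    "For each raw file, we keep only `cart` and `purchase` events and label them as described above, split `category_code` into up to four levels (`cat_0` to `cat_3`), fill missing brands and category levels with `NA`, and parse `event_time` into atomic columns (hour, minute, weekday, day, month, and year). The processed files are written as uncompressed CSVs to `./dataset`."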
   ]
  },
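  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the category split and time parsing, the next cell applies the same kind of transformations to three hypothetical rows. It is a sketch, not the pipeline itself: the `toy_` names are illustrative, and the `reindex` call is an addition not present in `process_files` that keeps four category columns even when no row has a four-level code. For these rows, `cat_2` should be `tv`, `light`, `NA` and `ts_weekday` should be 5, 5, 1 (Monday is 0)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical rows in the raw REES46 format (the last one has no category code)\n",
    "toy_raw = pd.DataFrame({\n",
    "    'category_code': ['electronics.video.tv', 'construction.tools.light', None],\n",
    "    'event_time': ['2020-02-01 00:00:18 UTC', '2020-02-01 00:00:40 UTC', '2019-12-31 23:59:59 UTC'],\n",
    "})\n",
    "\n",
    "# Split the dotted category path into at most four levels; a single-character\n",
    "# pattern is matched literally, and missing levels are filled with 'NA'\n",
    "toy_cats = toy_raw['category_code'].str.split('.', n=3, expand=True)\n",
    "toy_cats = toy_cats.reindex(columns=range(4)).fillna('NA')\n",
    "toy_cats.columns = ['cat_0', 'cat_1', 'cat_2', 'cat_3']\n",
    "\n",
    "# Strip the ' UTC' suffix, parse the timestamp, and derive atomic time features\n",
    "toy_ts = pd.to_datetime(toy_raw['event_time'].str.replace(' UTC', '', regex=False))\n",
    "toy_features = toy_cats.assign(\n",
    "    ts_hour=toy_ts.dt.hour,\n",
    "    ts_minute=toy_ts.dt.minute,\n",
    "    ts_weekday=toy_ts.dt.weekday,\n",
    "    ts_day=toy_ts.dt.day,\n",
    "    ts_month=toy_ts.dt.month,\n",
    "    ts_year=toy_ts.dt.year,\n",
    ")\n",
    "toy_features"
   ]
  },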
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                   | 0/5 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2019-Dec.csv.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██████████████▊                                                           | 1/5 [04:16<17:05, 256.45s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2020-Apr.csv.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|█████████████████████████████▌                                            | 2/5 [08:34<12:51, 257.29s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2020-Mar.csv.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|████████████████████████████████████████████▍                             | 3/5 [12:02<07:49, 234.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2020-Feb.csv.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|███████████████████████████████████████████████████████████▏              | 4/5 [15:30<03:44, 224.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2020-Jan.csv.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████| 5/5 [19:05<00:00, 229.04s/it]\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "def process_files(file):\n",
    "    df_tmp = pd.read_csv(file, compression='gzip')\n",
    "    # Key that identifies an item within a session\n",
    "    df_tmp['session_purchase'] = df_tmp['user_session'] + '_' + df_tmp['product_id'].astype(str)\n",
    "    df_purchase = df_tmp[df_tmp['event_type'] == 'purchase']\n",
    "    df_cart = df_tmp[df_tmp['event_type'] == 'cart']\n",
    "    # Positives: purchases of items that were added to the cart in the same session\n",
    "    df_purchase = df_purchase[df_purchase['session_purchase'].isin(df_cart['session_purchase'])].copy()\n",
    "    # Negatives: cart events for items that were not purchased in the same session\n",
    "    df_cart = df_cart[~df_cart['session_purchase'].isin(df_purchase['session_purchase'])].copy()\n",
    "    df_cart['target'] = 0\n",
    "    df_purchase['target'] = 1\n",
    "    df = pd.concat([df_cart, df_purchase])\n",
    "    df = df.drop(['category_id', 'session_purchase'], axis=1)\n",
    "    # Split the dotted category path into four levels; missing levels become 'NA'\n",
    "    df[['cat_0', 'cat_1', 'cat_2', 'cat_3']] = df['category_code'].str.split('.', n=3, expand=True).fillna('NA')\n",
    "    df['brand'] = df['brand'].fillna('NA')\n",
    "    df = df.drop('category_code', axis=1)\n",
    "    # Parse the event time (dropping the ' UTC' suffix) and derive atomic time features\n",
    "    df['timestamp'] = pd.to_datetime(df['event_time'].str.replace(' UTC', ''))\n",
    "    df['ts_hour'] = df['timestamp'].dt.hour\n",
    "    df['ts_minute'] = df['timestamp'].dt.minute\n",
    "    df['ts_weekday'] = df['timestamp'].dt.weekday\n",
    "    df['ts_day'] = df['timestamp'].dt.day\n",
    "    df['ts_month'] = df['timestamp'].dt.month\n",
    "    df['ts_year'] = df['timestamp'].dt.year\n",
    "    # Write an uncompressed CSV, for example ./dataset/2019-Dec.csv\n",
    "    df.to_csv('./dataset/' + file.replace('.gz', ''), index=False)\n",
    "\n",
    "!mkdir -p ./dataset\n",
    "for file in tqdm(list_files):\n",
    "    print(file)\n",
    "    process_files(file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare the training, validation, and test datasets\n",
    "\n",
    "Next, we split the data into training, validation, and test sets by month: 3 months for training (December 2019 to February 2020), 1 month for validation (March 2020), and 1 month for testing (April 2020)."
   ]
  },
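  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split cell below filters on `ts_month` alone, which works because each calendar month appears in only one year of the downloaded files. A year-aware variant, sketched here on hypothetical toy rows, builds a `(year, month)` key so that the split would stay correct if files from several years were combined; the `toy_` names are illustrative only and are not used elsewhere in the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy rows, one per downloaded month\n",
    "toy_months = pd.DataFrame({\n",
    "    'ts_year': [2019, 2020, 2020, 2020, 2020],\n",
    "    'ts_month': [12, 1, 2, 3, 4],\n",
    "})\n",
    "\n",
    "# Encode (year, month) as a sortable integer such as 202003\n",
    "toy_period = toy_months['ts_year'] * 100 + toy_months['ts_month']\n",
    "toy_train = toy_months[toy_period <= 202002]  # December 2019 to February 2020\n",
    "toy_valid = toy_months[toy_period == 202003]  # March 2020\n",
    "toy_test = toy_months[toy_period == 202004]   # April 2020\n",
    "(len(toy_train), len(toy_valid), len(toy_test))"
   ]
  },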
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "lp = []\n",
    "list_files = glob.glob('./dataset/*.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-rw-r--r-- 1 root dip 479323170 Nov 16 22:47 ./dataset/2019-Dec.csv\n",
      "-rw-r--r-- 1 root dip 455992639 Nov 16 22:51 ./dataset/2020-Apr.csv\n",
      "-rw-r--r-- 1 root dip 453967664 Nov 16 22:58 ./dataset/2020-Feb.csv\n",
      "-rw-r--r-- 1 root dip 375205173 Nov 16 23:02 ./dataset/2020-Jan.csv\n",
      "-rw-r--r-- 1 root dip 403896607 Nov 16 22:55 ./dataset/2020-Mar.csv\n"
     ]
    }
   ],
   "source": [
    "!ls -l ./dataset/*.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "for file in list_files:\n",
    "    lp.append(pd.read_csv(file))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(13184044, 19)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.concat(lp)\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test = df[df['ts_month']==4]\n",
    "df_valid = df[df['ts_month']==3]\n",
    "df_train = df[(df['ts_month']!=3)&(df['ts_month']!=4)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((7949839, 19), (2461719, 19), (2772486, 19))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.shape, df_valid.shape, df_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p ./data\n",
    "df_train.to_parquet('./data/train.parquet', index=False)\n",
    "df_valid.to_parquet('./data/valid.parquet', index=False)\n",
    "df_test.to_parquet('./data/test.parquet', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>event_time</th>\n",
       "      <th>event_type</th>\n",
       "      <th>product_id</th>\n",
       "      <th>brand</th>\n",
       "      <th>price</th>\n",
       "      <th>user_id</th>\n",
       "      <th>user_session</th>\n",
       "      <th>target</th>\n",
       "      <th>cat_0</th>\n",
       "      <th>cat_1</th>\n",
       "      <th>cat_2</th>\n",
       "      <th>cat_3</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ts_hour</th>\n",
       "      <th>ts_minute</th>\n",
       "      <th>ts_weekday</th>\n",
       "      <th>ts_day</th>\n",
       "      <th>ts_month</th>\n",
       "      <th>ts_year</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2020-02-01 00:00:18 UTC</td>\n",
       "      <td>cart</td>\n",
       "      <td>100065078</td>\n",
       "      <td>xiaomi</td>\n",
       "      <td>568.61</td>\n",
       "      <td>526615078</td>\n",
       "      <td>5f0aab9f-f92e-4eff-b0d2-fcec5f553f01</td>\n",
       "      <td>0</td>\n",
       "      <td>construction</td>\n",
       "      <td>tools</td>\n",
       "      <td>light</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2020-02-01 00:00:18</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2020-02-01 00:00:18 UTC</td>\n",
       "      <td>cart</td>\n",
       "      <td>5701246</td>\n",
       "      <td>NaN</td>\n",
       "      <td>24.43</td>\n",
       "      <td>563902689</td>\n",
       "      <td>76cc9152-8a9f-43e9-b98a-ee484510f379</td>\n",
       "      <td>0</td>\n",
       "      <td>electronics</td>\n",
       "      <td>video</td>\n",
       "      <td>tv</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2020-02-01 00:00:18</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2020-02-01 00:00:31 UTC</td>\n",
       "      <td>cart</td>\n",
       "      <td>14701533</td>\n",
       "      <td>NaN</td>\n",
       "      <td>154.42</td>\n",
       "      <td>520953435</td>\n",
       "      <td>5f1c7752-cf92-41fc-9a16-e8897a90eee8</td>\n",
       "      <td>0</td>\n",
       "      <td>electronics</td>\n",
       "      <td>video</td>\n",
       "      <td>projector</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2020-02-01 00:00:31</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2020-02-01 00:00:40 UTC</td>\n",
       "      <td>cart</td>\n",
       "      <td>1004855</td>\n",
       "      <td>xiaomi</td>\n",
       "      <td>123.30</td>\n",
       "      <td>519236281</td>\n",
       "      <td>e512f514-dc7f-4fc9-9042-e3955989d395</td>\n",
       "      <td>0</td>\n",
       "      <td>construction</td>\n",
       "      <td>tools</td>\n",
       "      <td>light</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2020-02-01 00:00:40</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2020-02-01 00:00:47 UTC</td>\n",
       "      <td>cart</td>\n",
       "      <td>1005100</td>\n",
       "      <td>samsung</td>\n",
       "      <td>140.28</td>\n",
       "      <td>550305600</td>\n",
       "      <td>bd7a37b6-420d-4575-8852-ac825aff39b5</td>\n",
       "      <td>0</td>\n",
       "      <td>construction</td>\n",
       "      <td>tools</td>\n",
       "      <td>light</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2020-02-01 00:00:47</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2020</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                event_time event_type  product_id    brand   price    user_id  \\\n",
       "0  2020-02-01 00:00:18 UTC       cart   100065078   xiaomi  568.61  526615078   \n",
       "1  2020-02-01 00:00:18 UTC       cart     5701246      NaN   24.43  563902689   \n",
       "2  2020-02-01 00:00:31 UTC       cart    14701533      NaN  154.42  520953435   \n",
       "3  2020-02-01 00:00:40 UTC       cart     1004855   xiaomi  123.30  519236281   \n",
       "4  2020-02-01 00:00:47 UTC       cart     1005100  samsung  140.28  550305600   \n",
       "\n",
       "                           user_session  target         cat_0  cat_1  \\\n",
       "0  5f0aab9f-f92e-4eff-b0d2-fcec5f553f01       0  construction  tools   \n",
       "1  76cc9152-8a9f-43e9-b98a-ee484510f379       0   electronics  video   \n",
       "2  5f1c7752-cf92-41fc-9a16-e8897a90eee8       0   electronics  video   \n",
       "3  e512f514-dc7f-4fc9-9042-e3955989d395       0  construction  tools   \n",
       "4  bd7a37b6-420d-4575-8852-ac825aff39b5       0  construction  tools   \n",
       "\n",
       "       cat_2 cat_3            timestamp  ts_hour  ts_minute  ts_weekday  \\\n",
       "0      light   NaN  2020-02-01 00:00:18        0          0           5   \n",
       "1         tv   NaN  2020-02-01 00:00:18        0          0           5   \n",
       "2  projector   NaN  2020-02-01 00:00:31        0          0           5   \n",
       "3      light   NaN  2020-02-01 00:00:40        0          0           5   \n",
       "4      light   NaN  2020-02-01 00:00:47        0          0           5   \n",
       "\n",
       "   ts_day  ts_month  ts_year  \n",
       "0       1         2     2020  \n",
       "1       1         2     2020  \n",
       "2       1         2     2020  \n",
       "3       1         2     2020  \n",
       "4       1         2     2020  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing with NVTabular\n",
    "\n",
    "Next, we will use NVTabular for preprocessing and engineering more features. \n",
    "\n",
    "But first, we need to import the necessary libraries and initialize a Dask GPU cluster for computation.\n",
    "\n",
    "### Initialize Dask GPU cluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7.1\n"
     ]
    }
   ],
   "source": [
    "# Standard Libraries\n",
    "import os\n",
    "from time import time\n",
    "import re\n",
    "import shutil\n",
    "import glob\n",
    "import warnings\n",
    "\n",
    "# External Dependencies\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import cupy as cp\n",
    "import cudf\n",
    "import dask_cudf\n",
    "from dask_cuda import LocalCUDACluster\n",
    "from dask.distributed import Client\n",
    "from dask.utils import parse_bytes\n",
    "from dask.delayed import delayed\n",
    "import rmm\n",
    "\n",
    "# NVTabular\n",
    "import nvtabular as nvt\n",
    "import nvtabular.ops as ops\n",
    "from nvtabular.io import Shuffle\n",
    "from nvtabular.utils import _pynvml_mem_size, device_mem_size\n",
    "\n",
    "print(nvt.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rm: cannot remove './nvtabular_temp': No such file or directory\n"
     ]
    }
   ],
   "source": [
    "# Input data location, plus working, output, and statistics directories\n",
    "BASE_DIR = \"./nvtabular_temp\"\n",
    "!rm -rf $BASE_DIR && mkdir -p $BASE_DIR\n",
    "input_path = './dataset'\n",
    "dask_workdir = os.path.join(BASE_DIR, \"workdir\")\n",
    "output_path = os.path.join(BASE_DIR, \"output\")\n",
    "stats_path = os.path.join(BASE_DIR, \"stats\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example was tested on a DGX server with 8 GPUs. If you have fewer GPUs, modify the `NUM_GPUS` variable accordingly."
   ]
  },
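  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to avoid editing `NUM_GPUS` by hand is to derive the GPU list from the `CUDA_VISIBLE_DEVICES` environment variable. The sketch below is hypothetical: the helper name `visible_gpu_ids` and the fallback of 8 GPUs (matching the DGX setup above) are assumptions, and the helper does not query the hardware, so check its result against `nvidia-smi` before using it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def visible_gpu_ids(default_count=8):\n",
    "    # Hypothetical helper: if CUDA_VISIBLE_DEVICES is unset,\n",
    "    # assume GPUs 0..default_count-1 are available\n",
    "    env = os.environ.get('CUDA_VISIBLE_DEVICES')\n",
    "    if env is None:\n",
    "        return [str(i) for i in range(default_count)]\n",
    "    # Otherwise use the listed device IDs; an empty value hides all GPUs from CUDA\n",
    "    return [dev.strip() for dev in env.split(',') if dev.strip()]\n",
    "\n",
    "toy_gpu_ids = visible_gpu_ids()\n",
    "print(len(toy_gpu_ids), ','.join(toy_gpu_ids))"
   ]
  },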
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "distributed.preloading - INFO - Import preload module: dask_cuda.initialize\n",
      "distributed.preloading - INFO - Import preload module: dask_cuda.initialize\n",
      "distributed.preloading - INFO - Import preload module: dask_cuda.initialize\n",
      "distributed.preloading - INFO - Import preload module: dask_cuda.initialize\n",
      "distributed.preloading - INFO - Import preload module: dask_cuda.initialize\n",
      "distributed.preloading - INFO - Import preload module: dask_cuda.initialize\n",
      "distributed.preloading - INFO - Import preload module: dask_cuda.initialize\n",
      "distributed.preloading - INFO - Import preload module: dask_cuda.initialize\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "            <div>\n",
       "                <div style=\"\n",
       "                    width: 24px;\n",
       "                    height: 24px;\n",
       "                    background-color: #e1e1e1;\n",
       "                    border: 3px solid #9D9D9D;\n",
       "                    border-radius: 5px;\n",
       "                    position: absolute;\"> </div>\n",
       "                <div style=\"margin-left: 48px;\">\n",
       "                    <h3 style=\"margin-bottom: 0px;\">Client</h3>\n",
       "                    <p style=\"color: #9D9D9D; margin-bottom: 0px;\">Client-5fa9de34-4731-11ec-81a0-0242c0a88002</p>\n",
       "                    <table style=\"width: 100%; text-align: left;\">\n",
       "                    \n",
       "                <tr>\n",
       "                    <td style=\"text-align: left;\"><strong>Connection method:</strong> Cluster object</td>\n",
       "                    <td style=\"text-align: left;\"><strong>Cluster type:</strong> LocalCUDACluster</td>\n",
       "                </tr>\n",
       "                \n",
       "                <tr>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>Dashboard: </strong>\n",
       "                        <a href=\"http://127.0.0.1:8787/status\">http://127.0.0.1:8787/status</a>\n",
       "                    </td>\n",
       "                    <td style=\"text-align: left;\"></td>\n",
       "                </tr>\n",
       "                \n",
       "                    </table>\n",
       "                    \n",
       "                <details>\n",
       "                <summary style=\"margin-bottom: 20px;\"><h3 style=\"display: inline;\">Cluster Info</h3></summary>\n",
       "                \n",
       "            <div class=\"jp-RenderedHTMLCommon jp-RenderedHTML jp-mod-trusted jp-OutputArea-output\">\n",
       "                <div style=\"\n",
       "                    width: 24px;\n",
       "                    height: 24px;\n",
       "                    background-color: #e1e1e1;\n",
       "                    border: 3px solid #9D9D9D;\n",
       "                    border-radius: 5px;\n",
       "                    position: absolute;\"> </div>\n",
       "                <div style=\"margin-left: 48px;\">\n",
       "                    <h3 style=\"margin-bottom: 0px; margin-top: 0px;\">LocalCUDACluster</h3>\n",
       "                    <p style=\"color: #9D9D9D; margin-bottom: 0px;\">280d3dd0</p>\n",
       "                    <table style=\"width: 100%; text-align: left;\">\n",
       "                    \n",
       "            <tr>\n",
       "                <td style=\"text-align: left;\"><strong>Status:</strong> running</td>\n",
       "                <td style=\"text-align: left;\"><strong>Using processes:</strong> True</td>\n",
       "            </tr>\n",
       "        \n",
       "            <tr>\n",
       "                <td style=\"text-align: left;\">\n",
       "                    <strong>Dashboard:</strong> <a href=\"http://127.0.0.1:8787/status\">http://127.0.0.1:8787/status</a>\n",
       "                </td>\n",
       "                <td style=\"text-align: left;\"><strong>Workers:</strong> 8</td>\n",
       "            </tr>\n",
       "            <tr>\n",
       "                <td style=\"text-align: left;\">\n",
       "                    <strong>Total threads:</strong>\n",
       "                    8\n",
       "                </td>\n",
       "                <td style=\"text-align: left;\">\n",
       "                    <strong>Total memory:</strong>\n",
       "                    503.81 GiB\n",
       "                </td>\n",
       "            </tr>\n",
       "        \n",
       "                    </table>\n",
       "                    <details>\n",
       "                    <summary style=\"margin-bottom: 20px;\"><h3 style=\"display: inline;\">Scheduler Info</h3></summary>\n",
       "                    \n",
       "        <div style=\"\">\n",
       "            \n",
       "            <div>\n",
       "                <div style=\"\n",
       "                    width: 24px;\n",
       "                    height: 24px;\n",
       "                    background-color: #FFF7E5;\n",
       "                    border: 3px solid #FF6132;\n",
       "                    border-radius: 5px;\n",
       "                    position: absolute;\"> </div>\n",
       "                <div style=\"margin-left: 48px;\">\n",
       "                    <h3 style=\"margin-bottom: 0px;\">Scheduler</h3>\n",
       "                    <p style=\"color: #9D9D9D; margin-bottom: 0px;\">Scheduler-e2967024-5ba6-46fa-889b-ce05595cfe32</p>\n",
       "                    <table style=\"width: 100%; text-align: left;\">\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Comm:</strong> tcp://127.0.0.1:42281</td>\n",
       "                            <td style=\"text-align: left;\"><strong>Workers:</strong> 8</td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Dashboard:</strong> <a href=\"http://127.0.0.1:8787/status\">http://127.0.0.1:8787/status</a>\n",
       "                            </td>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Total threads:</strong>\n",
       "                                8\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Started:</strong>\n",
       "                                Just now\n",
       "                            </td>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Total memory:</strong>\n",
       "                                503.81 GiB\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                    </table>\n",
       "                </div>\n",
       "            </div>\n",
       "        \n",
       "            <details style=\"margin-left: 48px;\">\n",
       "            <summary style=\"margin-bottom: 20px;\"><h3 style=\"display: inline;\">Workers</h3></summary>\n",
       "            \n",
       "            <div style=\"margin-bottom: 20px;\">\n",
       "                <div style=\"width: 24px;\n",
       "                            height: 24px;\n",
       "                            background-color: #DBF5FF;\n",
       "                            border: 3px solid #4CC9FF;\n",
       "                            border-radius: 5px;\n",
       "                            position: absolute;\"> </div>\n",
       "                <div style=\"margin-left: 48px;\">\n",
       "                <details>\n",
       "                    <summary>\n",
       "                        <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 0</h4>\n",
       "                    </summary>\n",
       "                    <table style=\"width: 100%; text-align: left;\">\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Comm: </strong> tcp://127.0.0.1:45134</td>\n",
       "                            <td style=\"text-align: left;\"><strong>Total threads: </strong> 1</td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Dashboard: </strong>\n",
       "                                <a href=\"http://127.0.0.1:45519/status\">http://127.0.0.1:45519/status</a>\n",
       "                            </td>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Memory: </strong>\n",
       "                                62.98 GiB\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Nanny: </strong> tcp://127.0.0.1:40585</td>\n",
       "                            <td style=\"text-align: left;\"></td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td colspan=\"2\" style=\"text-align: left;\">\n",
       "                                <strong>Local directory: </strong>\n",
       "                                /hugectr/notebooks/nvtabular_temp/workdir/dask-worker-space/worker-m2ipvyh1\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        \n",
       "                <tr>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU: </strong>Tesla P100-SXM2-16GB\n",
       "                    </td>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU memory: </strong>\n",
       "                        15.90 GiB\n",
       "                    </td>\n",
       "                </tr>\n",
       "                \n",
       "                        \n",
       "                    </table>\n",
       "                </details>\n",
       "                </div>\n",
       "            </div>\n",
       "            \n",
       "            <div style=\"margin-bottom: 20px;\">\n",
       "                <div style=\"width: 24px;\n",
       "                            height: 24px;\n",
       "                            background-color: #DBF5FF;\n",
       "                            border: 3px solid #4CC9FF;\n",
       "                            border-radius: 5px;\n",
       "                            position: absolute;\"> </div>\n",
       "                <div style=\"margin-left: 48px;\">\n",
       "                <details>\n",
       "                    <summary>\n",
       "                        <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 1</h4>\n",
       "                    </summary>\n",
       "                    <table style=\"width: 100%; text-align: left;\">\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Comm: </strong> tcp://127.0.0.1:45442</td>\n",
       "                            <td style=\"text-align: left;\"><strong>Total threads: </strong> 1</td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Dashboard: </strong>\n",
       "                                <a href=\"http://127.0.0.1:36563/status\">http://127.0.0.1:36563/status</a>\n",
       "                            </td>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Memory: </strong>\n",
       "                                62.98 GiB\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Nanny: </strong> tcp://127.0.0.1:33615</td>\n",
       "                            <td style=\"text-align: left;\"></td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td colspan=\"2\" style=\"text-align: left;\">\n",
       "                                <strong>Local directory: </strong>\n",
       "                                /hugectr/notebooks/nvtabular_temp/workdir/dask-worker-space/worker-hhdw5hgq\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        \n",
       "                <tr>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU: </strong>Tesla P100-SXM2-16GB\n",
       "                    </td>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU memory: </strong>\n",
       "                        15.90 GiB\n",
       "                    </td>\n",
       "                </tr>\n",
       "                \n",
       "                        \n",
       "                    </table>\n",
       "                </details>\n",
       "                </div>\n",
       "            </div>\n",
       "            \n",
       "            <div style=\"margin-bottom: 20px;\">\n",
       "                <div style=\"width: 24px;\n",
       "                            height: 24px;\n",
       "                            background-color: #DBF5FF;\n",
       "                            border: 3px solid #4CC9FF;\n",
       "                            border-radius: 5px;\n",
       "                            position: absolute;\"> </div>\n",
       "                <div style=\"margin-left: 48px;\">\n",
       "                <details>\n",
       "                    <summary>\n",
       "                        <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 2</h4>\n",
       "                    </summary>\n",
       "                    <table style=\"width: 100%; text-align: left;\">\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Comm: </strong> tcp://127.0.0.1:38832</td>\n",
       "                            <td style=\"text-align: left;\"><strong>Total threads: </strong> 1</td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Dashboard: </strong>\n",
       "                                <a href=\"http://127.0.0.1:45537/status\">http://127.0.0.1:45537/status</a>\n",
       "                            </td>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Memory: </strong>\n",
       "                                62.98 GiB\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Nanny: </strong> tcp://127.0.0.1:34857</td>\n",
       "                            <td style=\"text-align: left;\"></td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td colspan=\"2\" style=\"text-align: left;\">\n",
       "                                <strong>Local directory: </strong>\n",
       "                                /hugectr/notebooks/nvtabular_temp/workdir/dask-worker-space/worker-3xv1lvg4\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        \n",
       "                <tr>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU: </strong>Tesla P100-SXM2-16GB\n",
       "                    </td>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU memory: </strong>\n",
       "                        15.90 GiB\n",
       "                    </td>\n",
       "                </tr>\n",
       "                \n",
       "                        \n",
       "                    </table>\n",
       "                </details>\n",
       "                </div>\n",
       "            </div>\n",
       "            \n",
       "            <div style=\"margin-bottom: 20px;\">\n",
       "                <div style=\"width: 24px;\n",
       "                            height: 24px;\n",
       "                            background-color: #DBF5FF;\n",
       "                            border: 3px solid #4CC9FF;\n",
       "                            border-radius: 5px;\n",
       "                            position: absolute;\"> </div>\n",
       "                <div style=\"margin-left: 48px;\">\n",
       "                <details>\n",
       "                    <summary>\n",
       "                        <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 3</h4>\n",
       "                    </summary>\n",
       "                    <table style=\"width: 100%; text-align: left;\">\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Comm: </strong> tcp://127.0.0.1:33468</td>\n",
       "                            <td style=\"text-align: left;\"><strong>Total threads: </strong> 1</td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Dashboard: </strong>\n",
       "                                <a href=\"http://127.0.0.1:34645/status\">http://127.0.0.1:34645/status</a>\n",
       "                            </td>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Memory: </strong>\n",
       "                                62.98 GiB\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Nanny: </strong> tcp://127.0.0.1:33516</td>\n",
       "                            <td style=\"text-align: left;\"></td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td colspan=\"2\" style=\"text-align: left;\">\n",
       "                                <strong>Local directory: </strong>\n",
       "                                /hugectr/notebooks/nvtabular_temp/workdir/dask-worker-space/worker-_v4e6b68\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        \n",
       "                <tr>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU: </strong>Tesla P100-SXM2-16GB\n",
       "                    </td>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU memory: </strong>\n",
       "                        15.90 GiB\n",
       "                    </td>\n",
       "                </tr>\n",
       "                \n",
       "                        \n",
       "                    </table>\n",
       "                </details>\n",
       "                </div>\n",
       "            </div>\n",
       "            \n",
       "            <div style=\"margin-bottom: 20px;\">\n",
       "                <div style=\"width: 24px;\n",
       "                            height: 24px;\n",
       "                            background-color: #DBF5FF;\n",
       "                            border: 3px solid #4CC9FF;\n",
       "                            border-radius: 5px;\n",
       "                            position: absolute;\"> </div>\n",
       "                <div style=\"margin-left: 48px;\">\n",
       "                <details>\n",
       "                    <summary>\n",
       "                        <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 4</h4>\n",
       "                    </summary>\n",
       "                    <table style=\"width: 100%; text-align: left;\">\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Comm: </strong> tcp://127.0.0.1:38052</td>\n",
       "                            <td style=\"text-align: left;\"><strong>Total threads: </strong> 1</td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Dashboard: </strong>\n",
       "                                <a href=\"http://127.0.0.1:33234/status\">http://127.0.0.1:33234/status</a>\n",
       "                            </td>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Memory: </strong>\n",
       "                                62.98 GiB\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Nanny: </strong> tcp://127.0.0.1:42282</td>\n",
       "                            <td style=\"text-align: left;\"></td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td colspan=\"2\" style=\"text-align: left;\">\n",
       "                                <strong>Local directory: </strong>\n",
       "                                /hugectr/notebooks/nvtabular_temp/workdir/dask-worker-space/worker-lbkl0_sc\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        \n",
       "                <tr>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU: </strong>Tesla P100-SXM2-16GB\n",
       "                    </td>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU memory: </strong>\n",
       "                        15.90 GiB\n",
       "                    </td>\n",
       "                </tr>\n",
       "                \n",
       "                        \n",
       "                    </table>\n",
       "                </details>\n",
       "                </div>\n",
       "            </div>\n",
       "            \n",
       "            <div style=\"margin-bottom: 20px;\">\n",
       "                <div style=\"width: 24px;\n",
       "                            height: 24px;\n",
       "                            background-color: #DBF5FF;\n",
       "                            border: 3px solid #4CC9FF;\n",
       "                            border-radius: 5px;\n",
       "                            position: absolute;\"> </div>\n",
       "                <div style=\"margin-left: 48px;\">\n",
       "                <details>\n",
       "                    <summary>\n",
       "                        <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 5</h4>\n",
       "                    </summary>\n",
       "                    <table style=\"width: 100%; text-align: left;\">\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Comm: </strong> tcp://127.0.0.1:46068</td>\n",
       "                            <td style=\"text-align: left;\"><strong>Total threads: </strong> 1</td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Dashboard: </strong>\n",
       "                                <a href=\"http://127.0.0.1:34702/status\">http://127.0.0.1:34702/status</a>\n",
       "                            </td>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Memory: </strong>\n",
       "                                62.98 GiB\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Nanny: </strong> tcp://127.0.0.1:39148</td>\n",
       "                            <td style=\"text-align: left;\"></td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td colspan=\"2\" style=\"text-align: left;\">\n",
       "                                <strong>Local directory: </strong>\n",
       "                                /hugectr/notebooks/nvtabular_temp/workdir/dask-worker-space/worker-47f2dcfj\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        \n",
       "                <tr>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU: </strong>Tesla P100-SXM2-16GB\n",
       "                    </td>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU memory: </strong>\n",
       "                        15.90 GiB\n",
       "                    </td>\n",
       "                </tr>\n",
       "                \n",
       "                        \n",
       "                    </table>\n",
       "                </details>\n",
       "                </div>\n",
       "            </div>\n",
       "            \n",
       "            <div style=\"margin-bottom: 20px;\">\n",
       "                <div style=\"width: 24px;\n",
       "                            height: 24px;\n",
       "                            background-color: #DBF5FF;\n",
       "                            border: 3px solid #4CC9FF;\n",
       "                            border-radius: 5px;\n",
       "                            position: absolute;\"> </div>\n",
       "                <div style=\"margin-left: 48px;\">\n",
       "                <details>\n",
       "                    <summary>\n",
       "                        <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 6</h4>\n",
       "                    </summary>\n",
       "                    <table style=\"width: 100%; text-align: left;\">\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Comm: </strong> tcp://127.0.0.1:41440</td>\n",
       "                            <td style=\"text-align: left;\"><strong>Total threads: </strong> 1</td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Dashboard: </strong>\n",
       "                                <a href=\"http://127.0.0.1:33288/status\">http://127.0.0.1:33288/status</a>\n",
       "                            </td>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Memory: </strong>\n",
       "                                62.98 GiB\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Nanny: </strong> tcp://127.0.0.1:36099</td>\n",
       "                            <td style=\"text-align: left;\"></td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td colspan=\"2\" style=\"text-align: left;\">\n",
       "                                <strong>Local directory: </strong>\n",
       "                                /hugectr/notebooks/nvtabular_temp/workdir/dask-worker-space/worker-qq2t20ws\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        \n",
       "                <tr>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU: </strong>Tesla P100-SXM2-16GB\n",
       "                    </td>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU memory: </strong>\n",
       "                        15.90 GiB\n",
       "                    </td>\n",
       "                </tr>\n",
       "                \n",
       "                        \n",
       "                    </table>\n",
       "                </details>\n",
       "                </div>\n",
       "            </div>\n",
       "            \n",
       "            <div style=\"margin-bottom: 20px;\">\n",
       "                <div style=\"width: 24px;\n",
       "                            height: 24px;\n",
       "                            background-color: #DBF5FF;\n",
       "                            border: 3px solid #4CC9FF;\n",
       "                            border-radius: 5px;\n",
       "                            position: absolute;\"> </div>\n",
       "                <div style=\"margin-left: 48px;\">\n",
       "                <details>\n",
       "                    <summary>\n",
       "                        <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 7</h4>\n",
       "                    </summary>\n",
       "                    <table style=\"width: 100%; text-align: left;\">\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Comm: </strong> tcp://127.0.0.1:43583</td>\n",
       "                            <td style=\"text-align: left;\"><strong>Total threads: </strong> 1</td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Dashboard: </strong>\n",
       "                                <a href=\"http://127.0.0.1:45175/status\">http://127.0.0.1:45175/status</a>\n",
       "                            </td>\n",
       "                            <td style=\"text-align: left;\">\n",
       "                                <strong>Memory: </strong>\n",
       "                                62.98 GiB\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td style=\"text-align: left;\"><strong>Nanny: </strong> tcp://127.0.0.1:35315</td>\n",
       "                            <td style=\"text-align: left;\"></td>\n",
       "                        </tr>\n",
       "                        <tr>\n",
       "                            <td colspan=\"2\" style=\"text-align: left;\">\n",
       "                                <strong>Local directory: </strong>\n",
       "                                /hugectr/notebooks/nvtabular_temp/workdir/dask-worker-space/worker-1v3qmqn5\n",
       "                            </td>\n",
       "                        </tr>\n",
       "                        \n",
       "                <tr>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU: </strong>Tesla P100-SXM2-16GB\n",
       "                    </td>\n",
       "                    <td style=\"text-align: left;\">\n",
       "                        <strong>GPU memory: </strong>\n",
       "                        15.90 GiB\n",
       "                    </td>\n",
       "                </tr>\n",
       "                \n",
       "                        \n",
       "                    </table>\n",
       "                </details>\n",
       "                </div>\n",
       "            </div>\n",
       "            \n",
       "            </details>\n",
       "        </div>\n",
       "        \n",
       "                    </details>\n",
       "                </div>\n",
       "            </div>\n",
       "        \n",
       "                </details>\n",
       "                \n",
       "                </div>\n",
       "            </div>\n",
       "        "
      ],
      "text/plain": [
       "<Client: 'tcp://127.0.0.1:42281' processes=8 threads=8, memory=503.81 GiB>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "NUM_GPUS = [0,1,2,3,4,5,6,7]\n",
    "# NUM_GPUS = [0]  # Uncomment to run on a single GPU instead\n",
    "\n",
    "# Dask dashboard\n",
    "dashboard_port = \"8787\"\n",
    "\n",
    "# Deploy a Single-Machine Multi-GPU Cluster\n",
    "protocol = \"tcp\"             # \"tcp\" or \"ucx\"\n",
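    "visible_devices = \",\".join([str(n) for n in NUM_GPUS])  # Select the devices to place workers on\n",
    "device_limit_frac = 0.5      # Spill GPU-Worker memory to host at this limit.\n",
    "device_pool_frac = 0.6       # Initial RMM pool size, as a fraction of device memory.\n",
    "part_mem_frac = 0.05         # Partition size, as a fraction of device memory.\n",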
    "\n",
    "# Use the total device memory to convert the fractions above into sizes in bytes\n",
    "device_size = device_mem_size(kind=\"total\")\n",
    "device_limit = int(device_limit_frac * device_size)\n",
    "device_pool_size = int(device_pool_frac * device_size)\n",
    "part_size = int(part_mem_frac * device_size)\n",
    "\n",
    "# Check if any device memory is already occupied\n",
    "\"\"\"\n",
    "for dev in visible_devices.split(\",\"):\n",
    "    fmem = _pynvml_mem_size(kind=\"free\", index=int(dev))\n",
    "    used = (device_size - fmem) / 1e9\n",
    "    if used > 1.0:\n",
    "        warnings.warn(f\"BEWARE - {used} GB is already occupied on device {int(dev)}!\")\n",
    "\"\"\"\n",
    "\n",
    "cluster = None               # (Optional) Specify existing scheduler port\n",
    "if cluster is None:\n",
    "    cluster = LocalCUDACluster(\n",
    "        protocol = protocol,\n",
    "        n_workers=len(visible_devices.split(\",\")),\n",
    "        CUDA_VISIBLE_DEVICES = visible_devices,\n",
    "        device_memory_limit = device_limit,\n",
    "        local_directory=dask_workdir,\n",
    "        dashboard_address=\":\" + dashboard_port,\n",
    "    )\n",
    "\n",
    "# Create the distributed client\n",
    "client = Client(cluster)\n",
    "client"
   ]
  },
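  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the cluster cell above, each fraction is multiplied by the total device memory to obtain a size in bytes, and the RMM cell later rounds the pool size down to a multiple of 256. The following sketch reproduces that arithmetic as a standalone function so the numbers can be checked without a GPU. It assumes a hypothetical 16 GiB device (the P100 GPUs used here report 16280 MiB), and `memory_budget` is an illustrative helper of our own, not part of NVTabular or Dask-CUDA.\n",
    "\n",
    "```python\n",
    "def memory_budget(device_size, limit_frac=0.5, pool_frac=0.6, part_frac=0.05):\n",
    "    # Mirror the cluster cell: apply each fraction to the total device memory (bytes)\n",
    "    device_limit = int(limit_frac * device_size)     # spill-to-host threshold\n",
    "    device_pool_size = int(pool_frac * device_size)  # requested RMM pool size\n",
    "    part_size = int(part_frac * device_size)         # partition size\n",
    "    # Round the pool size down to a multiple of 256 bytes, as the RMM cell does\n",
    "    rmm_pool_size = (device_pool_size // 256) * 256\n",
    "    return {\n",
    "        'device_limit': device_limit,\n",
    "        'device_pool_size': device_pool_size,\n",
    "        'rmm_pool_size': rmm_pool_size,\n",
    "        'part_size': part_size,\n",
    "    }\n",
    "\n",
    "\n",
    "# Hypothetical 16 GiB device; in the notebook, device_mem_size(kind='total') supplies this value\n",
    "budget = memory_budget(16 * 1024**3)\n",
    "for name, size in budget.items():\n",
    "    print(f'{name}: {size / 1024**3:.2f} GiB')\n",
    "```\n",
    "\n",
    "With a 16 GiB device this gives an 8 GiB spill threshold, an RMM pool of about 9.6 GiB, and partitions of about 0.8 GiB. The pool fraction is larger than the spill fraction; as we understand Dask-CUDA, `device_memory_limit` governs when a worker moves the data it holds to host memory, while the RMM pool is reserved up front, so the two values control different things."
   ]
  },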
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tue Nov 16 23:03:13 2021       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 450.51.06    Driver Version: 450.51.06    CUDA Version: 11.4     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  Tesla P100-SXM2...  On   | 00000000:06:00.0 Off |                    0 |\n",
      "| N/A   41C    P0    44W / 300W |    508MiB / 16280MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  Tesla P100-SXM2...  On   | 00000000:07:00.0 Off |                    0 |\n",
      "| N/A   38C    P0    41W / 300W |    255MiB / 16280MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   2  Tesla P100-SXM2...  On   | 00000000:0A:00.0 Off |                    0 |\n",
      "| N/A   38C    P0    44W / 300W |    255MiB / 16280MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   3  Tesla P100-SXM2...  On   | 00000000:0B:00.0 Off |                    0 |\n",
      "| N/A   39C    P0    45W / 300W |    255MiB / 16280MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   4  Tesla P100-SXM2...  On   | 00000000:85:00.0 Off |                    0 |\n",
      "| N/A   43C    P0    43W / 300W |    255MiB / 16280MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   5  Tesla P100-SXM2...  On   | 00000000:86:00.0 Off |                    0 |\n",
      "| N/A   44C    P0    44W / 300W |    255MiB / 16280MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   6  Tesla P100-SXM2...  On   | 00000000:89:00.0 Off |                    0 |\n",
      "| N/A   39C    P0    44W / 300W |    255MiB / 16280MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   7  Tesla P100-SXM2...  On   | 00000000:8A:00.0 Off |                    0 |\n",
      "| N/A   38C    P0    42W / 300W |    255MiB / 16280MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
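  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `nvidia-smi` output above shows between 255 MiB and 508 MiB already in use on each GPU. The commented-out check in the cluster cell warns when more than 1 GB (in units of 1e9 bytes) is occupied on a device. The sketch below applies the same rule to the CSV output of `nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits`, which reports used memory in MiB. The helper `occupied_devices` is hypothetical, and the sample string stands in for real command output.\n",
    "\n",
    "```python\n",
    "def occupied_devices(query_output, threshold_gb=1.0):\n",
    "    # Each line looks like '0, 508': GPU index and used memory in MiB\n",
    "    busy = {}\n",
    "    for line in query_output.strip().splitlines():\n",
    "        index, used_mib = (field.strip() for field in line.split(','))\n",
    "        # Convert MiB to GB, matching the 1e9 divisor in the commented-out check\n",
    "        used_gb = int(used_mib) * 1024**2 / 1e9\n",
    "        if used_gb > threshold_gb:\n",
    "            busy[int(index)] = round(used_gb, 2)\n",
    "    return busy\n",
    "\n",
    "\n",
    "# Sample values modeled on the table above, plus one hypothetical busy device (GPU 2)\n",
    "sample = '''\n",
    "0, 508\n",
    "1, 255\n",
    "2, 4096\n",
    "'''\n",
    "print(occupied_devices(sample))\n",
    "```\n",
    "\n",
    "On a live system, the input could come from `subprocess.run(['nvidia-smi', '--query-gpu=index,memory.used', '--format=csv,noheader,nounits'], capture_output=True, text=True).stdout`."
   ]
  },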
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'tcp://127.0.0.1:33468': None,\n",
       " 'tcp://127.0.0.1:38052': None,\n",
       " 'tcp://127.0.0.1:38832': None,\n",
       " 'tcp://127.0.0.1:41440': None,\n",
       " 'tcp://127.0.0.1:43583': None,\n",
       " 'tcp://127.0.0.1:45134': None,\n",
       " 'tcp://127.0.0.1:45442': None,\n",
       " 'tcp://127.0.0.1:46068': None}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Initialize RMM pool on ALL workers\n",
    "def _rmm_pool():\n",
    "    rmm.reinitialize(\n",
    "        pool_allocator=True,\n",
    "        # RMM may require the pool size to be a multiple of 256, so round down.\n",
    "        initial_pool_size=(device_pool_size // 256) * 256,\n",
    "    )\n",
    "    \n",
    "client.run(_rmm_pool)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define NVTabular dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_paths = glob.glob('./data/train.parquet')\n",
    "valid_paths = glob.glob('./data/valid.parquet')\n",
    "test_paths = glob.glob('./data/test.parquet')\n",
    "\n",
    "train_dataset = nvt.Dataset(train_paths, engine='parquet', part_mem_fraction=0.15)\n",
    "valid_dataset = nvt.Dataset(valid_paths, engine='parquet', part_mem_fraction=0.15)\n",
    "test_dataset = nvt.Dataset(test_paths, engine='parquet', part_mem_fraction=0.15)"
   ]
  },
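  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`glob.glob` returns an empty list, rather than raising an error, when a path does not exist, so a missing or misspelled Parquet file is not reported by the glob call itself. A small guard such as the hypothetical `resolve_paths` below fails early with a clear message. It is a sketch that uses only the standard library and is not part of NVTabular.\n",
    "\n",
    "```python\n",
    "import glob\n",
    "\n",
    "\n",
    "def resolve_paths(pattern):\n",
    "    # glob.glob returns [] for a missing path, so raise explicitly instead\n",
    "    paths = sorted(glob.glob(pattern))\n",
    "    if not paths:\n",
    "        raise FileNotFoundError(f'No files match {pattern!r}')\n",
    "    return paths\n",
    "\n",
    "\n",
    "# Usage with the paths defined above:\n",
    "# train_paths = resolve_paths('./data/train.parquet')\n",
    "```"
   ]
  },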
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>event_time</th>\n",
       "      <th>event_type</th>\n",
       "      <th>product_id</th>\n",
       "      <th>brand</th>\n",
       "      <th>price</th>\n",
       "      <th>user_id</th>\n",
       "      <th>user_session</th>\n",
       "      <th>target</th>\n",
       "      <th>cat_0</th>\n",
       "      <th>cat_1</th>\n",
       "      <th>cat_2</th>\n",
       "      <th>cat_3</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ts_hour</th>\n",
       "      <th>ts_minute</th>\n",
       "      <th>ts_weekday</th>\n",
       "      <th>ts_day</th>\n",
       "      <th>ts_month</th>\n",
       "      <th>ts_year</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2020-02-01 00:00:18 UTC</td>\n",
       "      <td>cart</td>\n",
       "      <td>100065078</td>\n",
       "      <td>xiaomi</td>\n",
       "      <td>568.61</td>\n",
       "      <td>526615078</td>\n",
       "      <td>5f0aab9f-f92e-4eff-b0d2-fcec5f553f01</td>\n",
       "      <td>0</td>\n",
       "      <td>construction</td>\n",
       "      <td>tools</td>\n",
       "      <td>light</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>2020-02-01 00:00:18</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2020-02-01 00:00:18 UTC</td>\n",
       "      <td>cart</td>\n",
       "      <td>5701246</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>24.43</td>\n",
       "      <td>563902689</td>\n",
       "      <td>76cc9152-8a9f-43e9-b98a-ee484510f379</td>\n",
       "      <td>0</td>\n",
       "      <td>electronics</td>\n",
       "      <td>video</td>\n",
       "      <td>tv</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>2020-02-01 00:00:18</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2020-02-01 00:00:31 UTC</td>\n",
       "      <td>cart</td>\n",
       "      <td>14701533</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>154.42</td>\n",
       "      <td>520953435</td>\n",
       "      <td>5f1c7752-cf92-41fc-9a16-e8897a90eee8</td>\n",
       "      <td>0</td>\n",
       "      <td>electronics</td>\n",
       "      <td>video</td>\n",
       "      <td>projector</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>2020-02-01 00:00:31</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2020-02-01 00:00:40 UTC</td>\n",
       "      <td>cart</td>\n",
       "      <td>1004855</td>\n",
       "      <td>xiaomi</td>\n",
       "      <td>123.30</td>\n",
       "      <td>519236281</td>\n",
       "      <td>e512f514-dc7f-4fc9-9042-e3955989d395</td>\n",
       "      <td>0</td>\n",
       "      <td>construction</td>\n",
       "      <td>tools</td>\n",
       "      <td>light</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>2020-02-01 00:00:40</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2020-02-01 00:00:47 UTC</td>\n",
       "      <td>cart</td>\n",
       "      <td>1005100</td>\n",
       "      <td>samsung</td>\n",
       "      <td>140.28</td>\n",
       "      <td>550305600</td>\n",
       "      <td>bd7a37b6-420d-4575-8852-ac825aff39b5</td>\n",
       "      <td>0</td>\n",
       "      <td>construction</td>\n",
       "      <td>tools</td>\n",
       "      <td>light</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>2020-02-01 00:00:47</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2020</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                event_time event_type  product_id    brand   price    user_id  \\\n",
       "0  2020-02-01 00:00:18 UTC       cart   100065078   xiaomi  568.61  526615078   \n",
       "1  2020-02-01 00:00:18 UTC       cart     5701246     <NA>   24.43  563902689   \n",
       "2  2020-02-01 00:00:31 UTC       cart    14701533     <NA>  154.42  520953435   \n",
       "3  2020-02-01 00:00:40 UTC       cart     1004855   xiaomi  123.30  519236281   \n",
       "4  2020-02-01 00:00:47 UTC       cart     1005100  samsung  140.28  550305600   \n",
       "\n",
       "                           user_session  target         cat_0  cat_1  \\\n",
       "0  5f0aab9f-f92e-4eff-b0d2-fcec5f553f01       0  construction  tools   \n",
       "1  76cc9152-8a9f-43e9-b98a-ee484510f379       0   electronics  video   \n",
       "2  5f1c7752-cf92-41fc-9a16-e8897a90eee8       0   electronics  video   \n",
       "3  e512f514-dc7f-4fc9-9042-e3955989d395       0  construction  tools   \n",
       "4  bd7a37b6-420d-4575-8852-ac825aff39b5       0  construction  tools   \n",
       "\n",
       "       cat_2 cat_3            timestamp  ts_hour  ts_minute  ts_weekday  \\\n",
       "0      light  <NA>  2020-02-01 00:00:18        0          0           5   \n",
       "1         tv  <NA>  2020-02-01 00:00:18        0          0           5   \n",
       "2  projector  <NA>  2020-02-01 00:00:31        0          0           5   \n",
       "3      light  <NA>  2020-02-01 00:00:40        0          0           5   \n",
       "4      light  <NA>  2020-02-01 00:00:47        0          0           5   \n",
       "\n",
       "   ts_day  ts_month  ts_year  \n",
       "0       1         2     2020  \n",
       "1       1         2     2020  \n",
       "2       1         2     2020  \n",
       "3       1         2     2020  \n",
       "4       1         2     2020  "
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.to_ddf().head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "19"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_dataset.to_ddf().columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['event_time', 'event_type', 'product_id', 'brand', 'price', 'user_id',\n",
       "       'user_session', 'target', 'cat_0', 'cat_1', 'cat_2', 'cat_3',\n",
       "       'timestamp', 'ts_hour', 'ts_minute', 'ts_weekday', 'ts_day', 'ts_month',\n",
       "       'ts_year'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.to_ddf().columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7949839"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_dataset.to_ddf())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing and feature engineering\n",
    "\n",
    "In this notebook we explore a few feature engineering techniques with NVTabular:\n",
    "\n",
    "- Creating cross features, e.g. combining `user_id` and `brand` into a single categorical feature\n",
    "- Target encoding, i.e. replacing a categorical value with a smoothed mean of `target` for that value\n",
    "\n",
    "The engineered features are then preprocessed into a form suitable for a machine learning model:\n",
    "\n",
    "- Filling missing values\n",
    "- Encoding categorical features as integer IDs\n",
    "- Normalizing numeric features"
   ]
  },
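  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the cross features concrete, here is a small sketch that applies the two cross-maker functions defined in the next cell directly to a toy DataFrame. It assumes pandas as a CPU stand-in for cuDF (the functions only use `astype(str)` and string concatenation), and the three rows are made up for illustration. Inside the workflow, `LambdaOp` calls these functions for you.\n",
    "\n",
    "```python\n",
    "# Sketch only: pandas stands in for cuDF, and the toy rows are hypothetical\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def user_id_cross_maker(col, gdf):\n",
    "    # Same logic as the notebook: '<value>_<user_id>'\n",
    "    return col.astype(str) + '_' + gdf['user_id'].astype(str)\n",
    "\n",
    "\n",
    "def user_id_brand_cross_maker(col, gdf):\n",
    "    # Same logic as the notebook: '<value>_<user_id>_<brand>'\n",
    "    return col.astype(str) + '_' + gdf['user_id'].astype(str) + '_' + gdf['brand'].astype(str)\n",
    "\n",
    "\n",
    "toy = pd.DataFrame({\n",
    "    'user_id': [526615078, 563902689, 526615078],\n",
    "    'brand': ['xiaomi', 'samsung', 'xiaomi'],\n",
    "    'cat_0': ['construction', 'electronics', 'construction'],\n",
    "    'ts_hour': [0, 0, 13],\n",
    "})\n",
    "\n",
    "brand_user_id_cross = user_id_cross_maker(toy['brand'], toy)\n",
    "ts_hour_user_id_cross = user_id_cross_maker(toy['ts_hour'], toy)\n",
    "cat_0_user_id_brand_cross = user_id_brand_cross_maker(toy['cat_0'], toy)\n",
    "\n",
    "print(brand_user_id_cross.tolist())\n",
    "# ['xiaomi_526615078', 'samsung_563902689', 'xiaomi_526615078']\n",
    "print(ts_hour_user_id_cross.tolist())\n",
    "# ['0_526615078', '0_563902689', '13_526615078']\n",
    "print(cat_0_user_id_brand_cross.tolist())\n",
    "# ['construction_526615078_xiaomi', 'electronics_563902689_samsung', 'construction_526615078_xiaomi']\n",
    "```\n",
    "\n",
    "`Categorify` then assigns one integer ID per distinct string. The first and third rows therefore share an ID for the `brand` cross, while the `ts_hour` cross keeps them apart."
   ]
  },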
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nvtabular.ops import LambdaOp\n",
    "\n",
    "# Cross features: concatenate each selected column with user_id as a string\n",
    "def user_id_cross_maker(col, gdf):\n",
    "    return col.astype(str) + '_' + gdf['user_id'].astype(str)\n",
    "\n",
    "user_id_cross_features = (\n",
    "    nvt.ColumnGroup(['product_id', 'brand', 'ts_hour', 'ts_minute']) >>\n",
    "    LambdaOp(user_id_cross_maker, dependency=['user_id']) >>\n",
    "    nvt.ops.Rename(postfix='_user_id_cross')\n",
    ")\n",
    "\n",
    "\n",
    "# Cross features: concatenate each selected column with both user_id and brand\n",
    "def user_id_brand_cross_maker(col, gdf):\n",
    "    return col.astype(str) + '_' + gdf['user_id'].astype(str) + '_' + gdf['brand'].astype(str)\n",
    "\n",
    "user_id_brand_cross_features = (\n",
    "    nvt.ColumnGroup(['ts_hour', 'ts_weekday', 'cat_0', 'cat_1', 'cat_2']) >>\n",
    "    LambdaOp(user_id_brand_cross_maker, dependency=['user_id', 'brand']) >>\n",
    "    nvt.ops.Rename(postfix='_user_id_brand_cross')\n",
    ")\n",
    "\n",
    "# 5-fold target encoding against `target`, smoothed with p_smooth=20;\n",
    "# the nested list ['ts_weekday', 'ts_day'] is encoded as one combined group\n",
    "target_encode = (\n",
    "    ['brand', 'user_id', 'product_id', 'cat_2', ['ts_weekday', 'ts_day']] >>\n",
    "    nvt.ops.TargetEncoding(\n",
    "        nvt.ColumnGroup('target'),\n",
    "        kfold=5,\n",
    "        p_smooth=20,\n",
    "        out_dtype=\"float32\",\n",
    "    )\n",
    ")\n",
    "\n",
    "# Categorical crosses are mapped to integer IDs; continuous columns are filled and normalized\n",
    "cat_feats = (user_id_brand_cross_features + user_id_cross_features) >> nvt.ops.Categorify()\n",
    "cont_feats = ['price', 'ts_weekday', 'ts_day', 'ts_month'] >> nvt.ops.FillMissing() >> nvt.ops.Normalize()\n",
    "cont_feats += target_encode >> nvt.ops.Rename(postfix='_TE')"
   ]
  },
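  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `TargetEncoding` op above replaces each category with a statistic of `target`. It uses `kfold=5` so that a row's own label does not leak into its encoding, and `p_smooth=20` to blend per-category means with the global mean. The pandas/NumPy sketch below illustrates this out-of-fold, smoothed mean under our own assumptions: a random fold assignment, the global target mean as the prior, and the blend `(sum + p_smooth * prior) / (count + p_smooth)`. It is meant to convey the idea, not to reproduce NVTabular's exact internals. The function name `kfold_target_encode` and the toy data are hypothetical.\n",
    "\n",
    "```python\n",
    "# Sketch of out-of-fold, smoothed target encoding; not NVTabular's implementation\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def kfold_target_encode(df, col, target, kfold=5, p_smooth=20, seed=42):\n",
    "    # Hypothetical helper: assign every row to one of `kfold` random folds\n",
    "    rng = np.random.default_rng(seed)\n",
    "    folds = rng.integers(0, kfold, size=len(df))\n",
    "    prior = df[target].mean()  # global mean, used as the smoothing prior\n",
    "    encoded = np.empty(len(df), dtype='float32')  # mirrors out_dtype='float32'\n",
    "    for k in range(kfold):\n",
    "        in_fold = folds == k\n",
    "        if not in_fold.any():\n",
    "            continue  # small inputs can leave a fold empty\n",
    "        # Statistics come only from the other folds, so a row never sees its own label\n",
    "        stats = df.loc[~in_fold].groupby(col)[target].agg(['sum', 'count'])\n",
    "        smoothed = (stats['sum'] + p_smooth * prior) / (stats['count'] + p_smooth)\n",
    "        # Categories missing from the other folds fall back to the prior\n",
    "        encoded[in_fold] = df.loc[in_fold, col].map(smoothed).fillna(prior).to_numpy()\n",
    "    return encoded\n",
    "\n",
    "\n",
    "toy = pd.DataFrame({\n",
    "    'brand': ['xiaomi', 'samsung', 'xiaomi', 'apple', 'samsung', 'xiaomi', 'apple', 'samsung'],\n",
    "    'target': [1, 0, 0, 1, 1, 1, 0, 0],\n",
    "})\n",
    "toy['TE_brand'] = kfold_target_encode(toy, 'brand', 'target', kfold=5, p_smooth=20)\n",
    "print(toy)\n",
    "```\n",
    "\n",
    "With only eight rows and `p_smooth=20`, every encoded value stays close to the prior of 0.5. On the real dataset, frequent categories have large counts, so their encodings can move much closer to their own purchase rate."
   ]
  },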
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/nvtabular/nvtabular/workflow/workflow.py:89: UserWarning: A global dask.distributed client has been detected, but the single-threaded scheduler will be used for execution. Please use the `client` argument to initialize a `Workflow` object with distributed-execution enabled.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "output = cat_feats + cont_feats + 'target'\n",
    "proc = nvt.Workflow(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
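    "### Visualizing the workflow as a DAG\n",
    "\n",
    "The `graph` attribute of `output` renders the workflow as a directed acyclic graph (DAG) with Graphviz, so we install the Graphviz system package first."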
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!apt install -y graphviz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.43.0 (0)\n -->\n<!-- Title: %3 Pages: 1 -->\n<svg width=\"2795pt\" height=\"548pt\"\n viewBox=\"0.00 0.00 2794.72 548.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 544)\">\n<title>%3</title>\n<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-544 2790.72,-544 2790.72,4 -4,4\"/>\n<!-- 0 -->\n<g id=\"node1\" class=\"node\">\n<title>0</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1956.94\" cy=\"-162\" rx=\"48.19\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1956.94\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">Rename</text>\n</g>\n<!-- 4 -->\n<g id=\"node9\" class=\"node\">\n<title>4</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1594.94\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1594.94\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n</g>\n<!-- 0&#45;&gt;4 -->\n<g id=\"edge7\" class=\"edge\">\n<title>0&#45;&gt;4</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1915.06,-152.9C1843.34,-139.03 1698.22,-110.97 1630.83,-97.94\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1631.19,-94.44 1620.71,-95.98 1629.86,-101.32 1631.19,-94.44\"/>\n</g>\n<!-- 7 -->\n<g id=\"node2\" class=\"node\">\n<title>7</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"2037.94\" cy=\"-234\" rx=\"84.49\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"2037.94\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">TargetEncoding</text>\n</g>\n<!-- 7&#45;&gt;0 -->\n<g id=\"edge1\" class=\"edge\">\n<title>7&#45;&gt;0</title>\n<path fill=\"none\" stroke=\"black\" d=\"M2018.75,-216.41C2008.03,-207.15 1994.53,-195.48 
1982.88,-185.41\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1985.02,-182.64 1975.17,-178.75 1980.45,-187.94 1985.02,-182.64\"/>\n</g>\n<!-- 1 -->\n<g id=\"node3\" class=\"node\">\n<title>1</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"66.94\" cy=\"-450\" rx=\"66.89\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"66.94\" y=\"-446.3\" font-family=\"Times,serif\" font-size=\"14.00\">SelectionOp</text>\n</g>\n<!-- 17 -->\n<g id=\"node8\" class=\"node\">\n<title>17</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"370.94\" cy=\"-378\" rx=\"138.38\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"370.94\" y=\"-374.3\" font-family=\"Times,serif\" font-size=\"14.00\">user_id_brand_cross_maker</text>\n</g>\n<!-- 1&#45;&gt;17 -->\n<g id=\"edge27\" class=\"edge\">\n<title>1&#45;&gt;17</title>\n<path fill=\"none\" stroke=\"black\" d=\"M116.12,-437.68C164.79,-426.47 239.93,-409.17 296.03,-396.25\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"296.96,-399.63 305.92,-393.97 295.39,-392.81 296.96,-399.63\"/>\n</g>\n<!-- 1_selector -->\n<g id=\"node4\" class=\"node\">\n<title>1_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"66.94\" cy=\"-522\" rx=\"55.79\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"66.94\" y=\"-518.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;user_id&#39;]</text>\n</g>\n<!-- 1_selector&#45;&gt;1 -->\n<g id=\"edge2\" class=\"edge\">\n<title>1_selector&#45;&gt;1</title>\n<path fill=\"none\" stroke=\"black\" d=\"M66.94,-503.7C66.94,-495.98 66.94,-486.71 66.94,-478.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"70.44,-478.1 66.94,-468.1 63.44,-478.1 70.44,-478.1\"/>\n</g>\n<!-- 2 -->\n<g id=\"node5\" class=\"node\">\n<title>2</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"218.94\" cy=\"-450\" rx=\"66.89\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"218.94\" y=\"-446.3\" font-family=\"Times,serif\" font-size=\"14.00\">SelectionOp</text>\n</g>\n<!-- 2&#45;&gt;17 -->\n<g id=\"edge28\" 
class=\"edge\">\n<title>2&#45;&gt;17</title>\n<path fill=\"none\" stroke=\"black\" d=\"M251.18,-434.15C272.95,-424.13 301.99,-410.75 326.02,-399.69\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"327.68,-402.78 335.3,-395.41 324.75,-396.42 327.68,-402.78\"/>\n</g>\n<!-- 2_selector -->\n<g id=\"node6\" class=\"node\">\n<title>2_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"204.94\" cy=\"-522\" rx=\"50.09\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"204.94\" y=\"-518.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;brand&#39;]</text>\n</g>\n<!-- 2_selector&#45;&gt;2 -->\n<g id=\"edge3\" class=\"edge\">\n<title>2_selector&#45;&gt;2</title>\n<path fill=\"none\" stroke=\"black\" d=\"M208.4,-503.7C209.95,-495.98 211.8,-486.71 213.52,-478.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"216.99,-478.6 215.52,-468.1 210.13,-477.22 216.99,-478.6\"/>\n</g>\n<!-- 3 -->\n<g id=\"node7\" class=\"node\">\n<title>3</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"701.94\" cy=\"-306\" rx=\"48.19\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"701.94\" y=\"-302.3\" font-family=\"Times,serif\" font-size=\"14.00\">Rename</text>\n</g>\n<!-- 5 -->\n<g id=\"node13\" class=\"node\">\n<title>5</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1070.94\" cy=\"-234\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1070.94\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n</g>\n<!-- 3&#45;&gt;5 -->\n<g id=\"edge9\" class=\"edge\">\n<title>3&#45;&gt;5</title>\n<path fill=\"none\" stroke=\"black\" d=\"M743.91,-297.04C817.04,-283.17 966.55,-254.8 1035.12,-241.8\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1035.85,-245.22 1045.02,-239.92 1034.55,-238.34 1035.85,-245.22\"/>\n</g>\n<!-- 17&#45;&gt;3 -->\n<g id=\"edge4\" class=\"edge\">\n<title>17&#45;&gt;3</title>\n<path fill=\"none\" stroke=\"black\" d=\"M440.33,-362.33C503.15,-349.04 594.24,-329.78 650.69,-317.84\"/>\n<polygon fill=\"black\" 
stroke=\"black\" points=\"651.76,-321.19 660.82,-315.7 650.31,-314.34 651.76,-321.19\"/>\n</g>\n<!-- 20 -->\n<g id=\"node34\" class=\"node\">\n<title>20</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1594.94\" cy=\"-18\" rx=\"62.29\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1594.94\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">output cols</text>\n</g>\n<!-- 4&#45;&gt;20 -->\n<g id=\"edge33\" class=\"edge\">\n<title>4&#45;&gt;20</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1594.94,-71.7C1594.94,-63.98 1594.94,-54.71 1594.94,-46.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1598.44,-46.1 1594.94,-36.1 1591.44,-46.1 1598.44,-46.1\"/>\n</g>\n<!-- 11 -->\n<g id=\"node10\" class=\"node\">\n<title>11</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1306.94\" cy=\"-162\" rx=\"59.59\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1306.94\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">Categorify</text>\n</g>\n<!-- 11&#45;&gt;4 -->\n<g id=\"edge5\" class=\"edge\">\n<title>11&#45;&gt;4</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1351.94,-150.06C1409.49,-136.08 1507.92,-112.15 1560.1,-99.47\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1561,-102.85 1569.89,-97.09 1559.35,-96.05 1561,-102.85\"/>\n</g>\n<!-- 19 -->\n<g id=\"node11\" class=\"node\">\n<title>19</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1523.94\" cy=\"-162\" rx=\"58.49\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1523.94\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">Normalize</text>\n</g>\n<!-- 19&#45;&gt;4 -->\n<g id=\"edge6\" class=\"edge\">\n<title>19&#45;&gt;4</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1540.77,-144.41C1550.41,-134.91 1562.61,-122.88 1572.99,-112.65\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1575.5,-115.08 1580.17,-105.57 1570.59,-110.1 1575.5,-115.08\"/>\n</g>\n<!-- 13 -->\n<g id=\"node12\" class=\"node\">\n<title>13</title>\n<ellipse fill=\"none\" 
stroke=\"black\" cx=\"1666.94\" cy=\"-162\" rx=\"66.89\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1666.94\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">SelectionOp</text>\n</g>\n<!-- 13&#45;&gt;4 -->\n<g id=\"edge8\" class=\"edge\">\n<title>13&#45;&gt;4</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1649.88,-144.41C1640.11,-134.91 1627.74,-122.88 1617.21,-112.65\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1619.54,-110.03 1609.93,-105.57 1614.66,-115.05 1619.54,-110.03\"/>\n</g>\n<!-- 5&#45;&gt;11 -->\n<g id=\"edge20\" class=\"edge\">\n<title>5&#45;&gt;11</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1095.31,-225.77C1133.07,-214.57 1206.53,-192.78 1256.17,-178.06\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1257.17,-181.41 1265.76,-175.21 1255.18,-174.7 1257.17,-181.41\"/>\n</g>\n<!-- 15 -->\n<g id=\"node14\" class=\"node\">\n<title>15</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1070.94\" cy=\"-306\" rx=\"48.19\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1070.94\" y=\"-302.3\" font-family=\"Times,serif\" font-size=\"14.00\">Rename</text>\n</g>\n<!-- 15&#45;&gt;5 -->\n<g id=\"edge10\" class=\"edge\">\n<title>15&#45;&gt;5</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1070.94,-287.7C1070.94,-279.98 1070.94,-270.71 1070.94,-262.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1074.44,-262.1 1070.94,-252.1 1067.44,-262.1 1074.44,-262.1\"/>\n</g>\n<!-- 6 -->\n<g id=\"node15\" class=\"node\">\n<title>6</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1961.94\" cy=\"-306\" rx=\"66.89\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1961.94\" y=\"-302.3\" font-family=\"Times,serif\" font-size=\"14.00\">SelectionOp</text>\n</g>\n<!-- 6&#45;&gt;7 -->\n<g id=\"edge12\" class=\"edge\">\n<title>6&#45;&gt;7</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1979.95,-288.41C1989.63,-279.5 2001.73,-268.36 2012.37,-258.56\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"2014.81,-261.06 
2019.8,-251.71 2010.07,-255.91 2014.81,-261.06\"/>\n</g>\n<!-- 6_selector -->\n<g id=\"node16\" class=\"node\">\n<title>6_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1946.94\" cy=\"-378\" rx=\"293.55\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1946.94\" y=\"-374.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;brand&#39;, &#39;user_id&#39;, &#39;product_id&#39;, &#39;cat_2&#39;, &#39;ts_weekday&#39;, &#39;ts_day&#39;]</text>\n</g>\n<!-- 6_selector&#45;&gt;6 -->\n<g id=\"edge11\" class=\"edge\">\n<title>6_selector&#45;&gt;6</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1950.65,-359.7C1952.3,-351.98 1954.29,-342.71 1956.13,-334.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1959.61,-334.62 1958.28,-324.1 1952.76,-333.15 1959.61,-334.62\"/>\n</g>\n<!-- 8 -->\n<g id=\"node17\" class=\"node\">\n<title>8</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"2113.94\" cy=\"-306\" rx=\"66.89\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"2113.94\" y=\"-302.3\" font-family=\"Times,serif\" font-size=\"14.00\">SelectionOp</text>\n</g>\n<!-- 8&#45;&gt;7 -->\n<g id=\"edge13\" class=\"edge\">\n<title>8&#45;&gt;7</title>\n<path fill=\"none\" stroke=\"black\" d=\"M2095.93,-288.41C2086.25,-279.5 2074.16,-268.36 2063.52,-258.56\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"2065.82,-255.91 2056.09,-251.71 2061.07,-261.06 2065.82,-255.91\"/>\n</g>\n<!-- 7_selector -->\n<g id=\"node18\" class=\"node\">\n<title>7_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"2492.94\" cy=\"-306\" rx=\"293.55\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"2492.94\" y=\"-302.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;brand&#39;, &#39;user_id&#39;, &#39;product_id&#39;, &#39;cat_2&#39;, &#39;ts_weekday&#39;, &#39;ts_day&#39;]</text>\n</g>\n<!-- 7_selector&#45;&gt;7 -->\n<g id=\"edge14\" class=\"edge\">\n<title>7_selector&#45;&gt;7</title>\n<path fill=\"none\" stroke=\"black\" d=\"M2389.7,-289.12C2305.89,-276.22 
2189.62,-258.34 2114.09,-246.72\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"2114.59,-243.25 2104.17,-245.19 2113.52,-250.17 2114.59,-243.25\"/>\n</g>\n<!-- 8_selector -->\n<g id=\"node19\" class=\"node\">\n<title>8_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"2309.94\" cy=\"-378\" rx=\"51.19\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"2309.94\" y=\"-374.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;target&#39;]</text>\n</g>\n<!-- 8_selector&#45;&gt;8 -->\n<g id=\"edge15\" class=\"edge\">\n<title>8_selector&#45;&gt;8</title>\n<path fill=\"none\" stroke=\"black\" d=\"M2275.35,-364.65C2243.91,-353.42 2197.21,-336.74 2162.12,-324.21\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"2163.15,-320.86 2152.55,-320.79 2160.79,-327.45 2163.15,-320.86\"/>\n</g>\n<!-- 9 -->\n<g id=\"node20\" class=\"node\">\n<title>9</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"994.94\" cy=\"-450\" rx=\"66.89\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"994.94\" y=\"-446.3\" font-family=\"Times,serif\" font-size=\"14.00\">SelectionOp</text>\n</g>\n<!-- 10 -->\n<g id=\"node22\" class=\"node\">\n<title>10</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1070.94\" cy=\"-378\" rx=\"106.68\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1070.94\" y=\"-374.3\" font-family=\"Times,serif\" font-size=\"14.00\">user_id_cross_maker</text>\n</g>\n<!-- 9&#45;&gt;10 -->\n<g id=\"edge17\" class=\"edge\">\n<title>9&#45;&gt;10</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1012.95,-432.41C1022.56,-423.56 1034.55,-412.52 1045.13,-402.77\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1047.55,-405.31 1052.53,-395.96 1042.81,-400.16 1047.55,-405.31\"/>\n</g>\n<!-- 9_selector -->\n<g id=\"node21\" class=\"node\">\n<title>9_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"975.94\" cy=\"-522\" rx=\"211.76\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"975.94\" y=\"-518.3\" font-family=\"Times,serif\" 
font-size=\"14.00\">[&#39;product_id&#39;, &#39;brand&#39;, &#39;ts_hour&#39;, &#39;ts_minute&#39;]</text>\n</g>\n<!-- 9_selector&#45;&gt;9 -->\n<g id=\"edge16\" class=\"edge\">\n<title>9_selector&#45;&gt;9</title>\n<path fill=\"none\" stroke=\"black\" d=\"M980.64,-503.7C982.76,-495.9 985.31,-486.51 987.66,-477.83\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"991.06,-478.67 990.3,-468.1 984.3,-476.84 991.06,-478.67\"/>\n</g>\n<!-- 10&#45;&gt;15 -->\n<g id=\"edge24\" class=\"edge\">\n<title>10&#45;&gt;15</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1070.94,-359.7C1070.94,-351.98 1070.94,-342.71 1070.94,-334.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1074.44,-334.1 1070.94,-324.1 1067.44,-334.1 1074.44,-334.1\"/>\n</g>\n<!-- 12 -->\n<g id=\"node23\" class=\"node\">\n<title>12</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1146.94\" cy=\"-450\" rx=\"66.89\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1146.94\" y=\"-446.3\" font-family=\"Times,serif\" font-size=\"14.00\">SelectionOp</text>\n</g>\n<!-- 12&#45;&gt;10 -->\n<g id=\"edge18\" class=\"edge\">\n<title>12&#45;&gt;10</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1128.93,-432.41C1119.33,-423.56 1107.33,-412.52 1096.75,-402.77\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1099.08,-400.16 1089.35,-395.96 1094.34,-405.31 1099.08,-400.16\"/>\n</g>\n<!-- 10_selector -->\n<g id=\"node24\" class=\"node\">\n<title>10_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1443.94\" cy=\"-450\" rx=\"211.76\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1443.94\" y=\"-446.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;product_id&#39;, &#39;brand&#39;, &#39;ts_hour&#39;, &#39;ts_minute&#39;]</text>\n</g>\n<!-- 10_selector&#45;&gt;10 -->\n<g id=\"edge19\" class=\"edge\">\n<title>10_selector&#45;&gt;10</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1360.7,-433.38C1297.55,-421.53 1211.33,-405.35 1149.39,-393.72\"/>\n<polygon fill=\"black\" 
stroke=\"black\" points=\"1149.75,-390.23 1139.28,-391.82 1148.46,-397.11 1149.75,-390.23\"/>\n</g>\n<!-- 12_selector -->\n<g id=\"node25\" class=\"node\">\n<title>12_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1261.94\" cy=\"-522\" rx=\"55.79\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1261.94\" y=\"-518.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;user_id&#39;]</text>\n</g>\n<!-- 12_selector&#45;&gt;12 -->\n<g id=\"edge21\" class=\"edge\">\n<title>12_selector&#45;&gt;12</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1236.99,-505.81C1220.79,-495.95 1199.45,-482.96 1181.61,-472.1\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1183.01,-468.85 1172.64,-466.64 1179.37,-474.83 1183.01,-468.85\"/>\n</g>\n<!-- 13_selector -->\n<g id=\"node26\" class=\"node\">\n<title>13_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1666.94\" cy=\"-234\" rx=\"51.19\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1666.94\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;target&#39;]</text>\n</g>\n<!-- 13_selector&#45;&gt;13 -->\n<g id=\"edge22\" class=\"edge\">\n<title>13_selector&#45;&gt;13</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1666.94,-215.7C1666.94,-207.98 1666.94,-198.71 1666.94,-190.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1670.44,-190.1 1666.94,-180.1 1663.44,-190.1 1670.44,-190.1\"/>\n</g>\n<!-- 14 -->\n<g id=\"node27\" class=\"node\">\n<title>14</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1383.94\" cy=\"-306\" rx=\"66.89\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1383.94\" y=\"-302.3\" font-family=\"Times,serif\" font-size=\"14.00\">SelectionOp</text>\n</g>\n<!-- 18 -->\n<g id=\"node32\" class=\"node\">\n<title>18</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1523.94\" cy=\"-234\" rx=\"62.29\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1523.94\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">FillMissing</text>\n</g>\n<!-- 
14&#45;&gt;18 -->\n<g id=\"edge30\" class=\"edge\">\n<title>14&#45;&gt;18</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1414.32,-289.81C1435.01,-279.47 1462.58,-265.68 1484.9,-254.52\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1486.73,-257.52 1494.11,-249.92 1483.6,-251.26 1486.73,-257.52\"/>\n</g>\n<!-- 14_selector -->\n<g id=\"node28\" class=\"node\">\n<title>14_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1399.94\" cy=\"-378\" rx=\"204.16\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1399.94\" y=\"-374.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;price&#39;, &#39;ts_weekday&#39;, &#39;ts_day&#39;, &#39;ts_month&#39;]</text>\n</g>\n<!-- 14_selector&#45;&gt;14 -->\n<g id=\"edge23\" class=\"edge\">\n<title>14_selector&#45;&gt;14</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1395.99,-359.7C1394.23,-351.98 1392.11,-342.71 1390.14,-334.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1393.49,-333.07 1387.85,-324.1 1386.67,-334.63 1393.49,-333.07\"/>\n</g>\n<!-- 16 -->\n<g id=\"node29\" class=\"node\">\n<title>16</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"370.94\" cy=\"-450\" rx=\"66.89\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"370.94\" y=\"-446.3\" font-family=\"Times,serif\" font-size=\"14.00\">SelectionOp</text>\n</g>\n<!-- 16&#45;&gt;17 -->\n<g id=\"edge26\" class=\"edge\">\n<title>16&#45;&gt;17</title>\n<path fill=\"none\" stroke=\"black\" d=\"M370.94,-431.7C370.94,-423.98 370.94,-414.71 370.94,-406.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"374.44,-406.1 370.94,-396.1 367.44,-406.1 374.44,-406.1\"/>\n</g>\n<!-- 16_selector -->\n<g id=\"node30\" class=\"node\">\n<title>16_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"499.94\" cy=\"-522\" rx=\"226.66\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"499.94\" y=\"-518.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;ts_hour&#39;, &#39;ts_weekday&#39;, &#39;cat_0&#39;, &#39;cat_1&#39;, 
&#39;cat_2&#39;]</text>\n</g>\n<!-- 16_selector&#45;&gt;16 -->\n<g id=\"edge25\" class=\"edge\">\n<title>16_selector&#45;&gt;16</title>\n<path fill=\"none\" stroke=\"black\" d=\"M468.72,-504.05C450.44,-494.13 427.24,-481.55 408.03,-471.13\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"409.7,-468.05 399.24,-466.36 406.36,-474.2 409.7,-468.05\"/>\n</g>\n<!-- 17_selector -->\n<g id=\"node31\" class=\"node\">\n<title>17_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"682.94\" cy=\"-450\" rx=\"226.66\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"682.94\" y=\"-446.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;ts_hour&#39;, &#39;ts_weekday&#39;, &#39;cat_0&#39;, &#39;cat_1&#39;, &#39;cat_2&#39;]</text>\n</g>\n<!-- 17_selector&#45;&gt;17 -->\n<g id=\"edge29\" class=\"edge\">\n<title>17_selector&#45;&gt;17</title>\n<path fill=\"none\" stroke=\"black\" d=\"M610.97,-432.85C561.88,-421.84 496.95,-407.27 447.01,-396.07\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"447.75,-392.65 437.23,-393.87 446.22,-399.48 447.75,-392.65\"/>\n</g>\n<!-- 18&#45;&gt;19 -->\n<g id=\"edge32\" class=\"edge\">\n<title>18&#45;&gt;19</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1523.94,-215.7C1523.94,-207.98 1523.94,-198.71 1523.94,-190.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1527.44,-190.1 1523.94,-180.1 1520.44,-190.1 1527.44,-190.1\"/>\n</g>\n<!-- 18_selector -->\n<g id=\"node33\" class=\"node\">\n<title>18_selector</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"1672.94\" cy=\"-306\" rx=\"204.16\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"1672.94\" y=\"-302.3\" font-family=\"Times,serif\" font-size=\"14.00\">[&#39;price&#39;, &#39;ts_weekday&#39;, &#39;ts_day&#39;, &#39;ts_month&#39;]</text>\n</g>\n<!-- 18_selector&#45;&gt;18 -->\n<g id=\"edge31\" class=\"edge\">\n<title>18_selector&#45;&gt;18</title>\n<path fill=\"none\" stroke=\"black\" d=\"M1637.25,-288.23C1615.31,-277.92 1587.13,-264.69 
1564.37,-253.99\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"1565.84,-250.81 1555.3,-249.73 1562.86,-257.15 1565.84,-250.81\"/>\n</g>\n</g>\n</svg>\n",
      "text/plain": [
       "<graphviz.dot.Digraph at 0x7f05e75470d0>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Executing the workflow\n",
    "\n",
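    "With the workflow defined, calling `fit()` starts the actual computation. NVTabular goes over the training data and records the statistics that the operators need, such as the category-to-ID mappings for `Categorify`, the mean and standard deviation for `Normalize`, and the per-category target statistics for `TargetEncoding`."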
   ]
  },
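  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what fitting means for the continuous branch of this workflow, the pandas sketch below mimics `FillMissing() >> Normalize()`. The fit step stores statistics computed on the training split only, and the transform step reuses them on any split, so validation and test data are scaled with training-set values. The helper names, the fill value of 0, and the use of the sample standard deviation are our assumptions, not NVTabular's documented behavior.\n",
    "\n",
    "```python\n",
    "# Sketch only: a fit/transform split for fill-missing + normalize, using pandas\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def fit_fill_normalize(train, cols, fill_val=0):\n",
    "    # Fit: compute per-column statistics on the training split only\n",
    "    filled = train[cols].fillna(fill_val)\n",
    "    return {'fill_val': fill_val, 'mean': filled.mean(), 'std': filled.std()}\n",
    "\n",
    "\n",
    "def transform_fill_normalize(df, cols, stats):\n",
    "    # Transform: reuse the stored training statistics on any split\n",
    "    filled = df[cols].fillna(stats['fill_val'])\n",
    "    return (filled - stats['mean']) / stats['std']\n",
    "\n",
    "\n",
    "cols = ['price', 'ts_day']\n",
    "# Toy rows loosely based on the sample output above; the missing price is made up\n",
    "train = pd.DataFrame({'price': [568.61, 24.43, 154.42, None], 'ts_day': [1, 1, 2, 3]})\n",
    "valid = pd.DataFrame({'price': [123.30, 140.28], 'ts_day': [1, 2]})\n",
    "\n",
    "stats = fit_fill_normalize(train, cols)  # analogous to proc.fit(train_dataset)\n",
    "train_norm = transform_fill_normalize(train, cols, stats)\n",
    "valid_norm = transform_fill_normalize(valid, cols, stats)  # no re-fitting on validation\n",
    "print(valid_norm)\n",
    "```\n",
    "\n",
    "Fitting on the training split only keeps the March and April 2020 data out of the statistics used to preprocess the training set."
   ]
  },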
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/numba/cuda/compiler.py:865: NumbaPerformanceWarning: Grid size (1) < 2 * SM count (112) will likely result in GPU under utilization due to low occupancy.\n",
      "  warn(NumbaPerformanceWarning(msg))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 13.3 s, sys: 10.1 s, total: 23.4 s\n",
      "Wall time: 26.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "time_preproc_start = time()\n",
    "proc.fit(train_dataset)\n",
    "time_preproc = time()-time_preproc_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['ts_hour_user_id_brand_cross',\n",
       " 'ts_weekday_user_id_brand_cross',\n",
       " 'cat_0_user_id_brand_cross',\n",
       " 'cat_1_user_id_brand_cross',\n",
       " 'cat_2_user_id_brand_cross',\n",
       " 'product_id_user_id_cross',\n",
       " 'brand_user_id_cross',\n",
       " 'ts_hour_user_id_cross',\n",
       " 'ts_minute_user_id_cross']"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cat_feats.output_columns.names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['ts_hour_user_id_brand_cross',\n",
       " 'ts_weekday_user_id_brand_cross',\n",
       " 'cat_0_user_id_brand_cross',\n",
       " 'cat_1_user_id_brand_cross',\n",
       " 'cat_2_user_id_brand_cross',\n",
       " 'product_id_user_id_cross',\n",
       " 'brand_user_id_cross',\n",
       " 'ts_hour_user_id_cross',\n",
       " 'ts_minute_user_id_cross',\n",
       " 'price',\n",
       " 'ts_weekday',\n",
       " 'ts_day',\n",
       " 'ts_month',\n",
       " 'TE_brand_target_TE',\n",
       " 'TE_user_id_target_TE',\n",
       " 'TE_product_id_target_TE',\n",
       " 'TE_cat_2_target_TE',\n",
       " 'TE_ts_weekday_ts_day_target_TE',\n",
       " 'target']"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.output_columns.names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "CAT_FEATS = ['ts_hour_user_id_brand_cross',\n",
    " 'ts_weekday_user_id_brand_cross',\n",
    " 'cat_0_user_id_brand_cross',\n",
    " 'cat_1_user_id_brand_cross',\n",
    " 'cat_2_user_id_brand_cross',\n",
    " 'product_id_user_id_cross',\n",
    " 'brand_user_id_cross',\n",
    " 'ts_hour_user_id_cross',\n",
    " 'ts_minute_user_id_cross',]\n",
    "\n",
    "CON_FEATS = ['price',\n",
    " 'ts_weekday',\n",
    " 'ts_day',\n",
    " 'ts_month',\n",
    " 'TE_brand_target_TE',\n",
    " 'TE_user_id_target_TE',\n",
    " 'TE_product_id_target_TE',\n",
    " 'TE_cat_2_target_TE',\n",
    " 'TE_ts_weekday_ts_day_target_TE']\n",
    "\n",
    "dict_dtypes = {}\n",
    "for col in CAT_FEATS:\n",
    "    dict_dtypes[col] = np.int64\n",
    "for col in CON_FEATS:\n",
    "    dict_dtypes[col] = np.float32\n",
    "\n",
    "dict_dtypes['target'] = np.float32"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we call the `transform()` method to apply the fitted workflow to the training, validation, and test datasets and write the results to Parquet files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_train_dir = os.path.join(output_path, 'train/')\n",
    "output_valid_dir = os.path.join(output_path, 'valid/')\n",
    "output_test_dir = os.path.join(output_path, 'test/')\n",
    "! rm -rf $output_train_dir && mkdir -p $output_train_dir\n",
    "! rm -rf $output_valid_dir && mkdir -p $output_valid_dir\n",
    "! rm -rf $output_test_dir && mkdir -p $output_test_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/nvtabular/nvtabular/io/dask.py:375: UserWarning: A global dask.distributed client has been detected, but the single-threaded scheduler will be used for this write operation. Please use the `client` argument to initialize a `Dataset` and/or `Workflow` object with distributed-execution enabled.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.6 s, sys: 3.29 s, total: 4.89 s\n",
      "Wall time: 5.79 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "time_preproc_start = time()\n",
    "proc.transform(train_dataset).to_parquet(output_path=output_train_dir, dtypes=dict_dtypes,\n",
    "                                         shuffle=nvt.io.Shuffle.PER_PARTITION,\n",
    "                                         cats=CAT_FEATS,\n",
    "                                         conts=CON_FEATS,\n",
    "                                         labels=['target'])\n",
    "time_preproc += time()-time_preproc_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 366131\n",
      "-rw-r--r-- 1 root dip        47 Nov 16 23:04 _file_list.txt\n",
      "-rw-r--r-- 1 root dip     18283 Nov 16 23:04 _metadata\n",
      "-rw-r--r-- 1 root dip      1045 Nov 16 23:04 _metadata.json\n",
      "-rw-r--r-- 1 root dip 706364298 Nov 16 23:04 part_0.parquet\n",
      "-rw-r--r-- 1 root dip      7975 Nov 16 23:04 schema.pbtxt\n"
     ]
    }
   ],
   "source": [
    "!ls -l $output_train_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.06 s, sys: 1.57 s, total: 2.63 s\n",
      "Wall time: 2.83 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "time_preproc_start = time()\n",
    "proc.transform(valid_dataset).to_parquet(output_path=output_valid_dir, dtypes=dict_dtypes,\n",
    "                                         shuffle=nvt.io.Shuffle.PER_PARTITION,\n",
    "                                         cats=CAT_FEATS,\n",
    "                                         conts=CON_FEATS,\n",
    "                                         labels=['target'])\n",
    "time_preproc += time()-time_preproc_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 100979\n",
      "-rw-r--r-- 1 root dip       47 Nov 16 23:04 _file_list.txt\n",
      "-rw-r--r-- 1 root dip     8983 Nov 16 23:04 _metadata\n",
      "-rw-r--r-- 1 root dip     1045 Nov 16 23:04 _metadata.json\n",
      "-rw-r--r-- 1 root dip 92826604 Nov 16 23:04 part_0.parquet\n",
      "-rw-r--r-- 1 root dip     7975 Nov 16 23:04 schema.pbtxt\n"
     ]
    }
   ],
   "source": [
    "!ls -l $output_valid_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.05 s, sys: 1.64 s, total: 2.69 s\n",
      "Wall time: 2.75 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "time_preproc_start = time()\n",
    "proc.transform(test_dataset).to_parquet(output_path=output_test_dir, dtypes=dict_dtypes,\n",
    "                                         shuffle=nvt.io.Shuffle.PER_PARTITION,\n",
    "                                         cats=CAT_FEATS,\n",
    "                                         conts=CON_FEATS,\n",
    "                                         labels=['target'])\n",
    "time_preproc += time()-time_preproc_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "38.198790550231934"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "time_preproc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Verify the preprocessed data\n",
    "\n",
    "Let's quickly read the data back and verify that all fields have the expected format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_file_list.txt\t_metadata  _metadata.json  part_0.parquet  schema.pbtxt\n"
     ]
    }
   ],
   "source": [
    "!ls $output_train_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ts_hour_user_id_brand_cross</th>\n",
       "      <th>ts_weekday_user_id_brand_cross</th>\n",
       "      <th>cat_0_user_id_brand_cross</th>\n",
       "      <th>cat_1_user_id_brand_cross</th>\n",
       "      <th>cat_2_user_id_brand_cross</th>\n",
       "      <th>product_id_user_id_cross</th>\n",
       "      <th>brand_user_id_cross</th>\n",
       "      <th>ts_hour_user_id_cross</th>\n",
       "      <th>ts_minute_user_id_cross</th>\n",
       "      <th>price</th>\n",
       "      <th>ts_weekday</th>\n",
       "      <th>ts_day</th>\n",
       "      <th>ts_month</th>\n",
       "      <th>TE_brand_target_TE</th>\n",
       "      <th>TE_user_id_target_TE</th>\n",
       "      <th>TE_product_id_target_TE</th>\n",
       "      <th>TE_cat_2_target_TE</th>\n",
       "      <th>TE_ts_weekday_ts_day_target_TE</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>817883</td>\n",
       "      <td>908980</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1085846</td>\n",
       "      <td>855303</td>\n",
       "      <td>144463</td>\n",
       "      <td>5417827</td>\n",
       "      <td>-0.725652</td>\n",
       "      <td>-0.502085</td>\n",
       "      <td>0.235441</td>\n",
       "      <td>1.314784</td>\n",
       "      <td>0.199788</td>\n",
       "      <td>0.325234</td>\n",
       "      <td>0.227236</td>\n",
       "      <td>-1.284867</td>\n",
       "      <td>0.387063</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4152058</td>\n",
       "      <td>2732403</td>\n",
       "      <td>1210052</td>\n",
       "      <td>712779</td>\n",
       "      <td>0</td>\n",
       "      <td>1360101</td>\n",
       "      <td>954962</td>\n",
       "      <td>731363</td>\n",
       "      <td>2230166</td>\n",
       "      <td>-0.836849</td>\n",
       "      <td>-0.007929</td>\n",
       "      <td>-0.802043</td>\n",
       "      <td>-0.864342</td>\n",
       "      <td>0.355313</td>\n",
       "      <td>0.266837</td>\n",
       "      <td>0.255486</td>\n",
       "      <td>-1.285741</td>\n",
       "      <td>0.420646</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3204608</td>\n",
       "      <td>274365</td>\n",
       "      <td>30730</td>\n",
       "      <td>31144</td>\n",
       "      <td>25505</td>\n",
       "      <td>29457</td>\n",
       "      <td>32720</td>\n",
       "      <td>3039842</td>\n",
       "      <td>3062261</td>\n",
       "      <td>-0.184922</td>\n",
       "      <td>-0.007929</td>\n",
       "      <td>1.618752</td>\n",
       "      <td>-0.864342</td>\n",
       "      <td>0.466206</td>\n",
       "      <td>0.237990</td>\n",
       "      <td>0.414308</td>\n",
       "      <td>0.459563</td>\n",
       "      <td>0.239809</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>3464677</td>\n",
       "      <td>0</td>\n",
       "      <td>2467278</td>\n",
       "      <td>2493129</td>\n",
       "      <td>-0.841169</td>\n",
       "      <td>-0.007929</td>\n",
       "      <td>0.465993</td>\n",
       "      <td>-0.666240</td>\n",
       "      <td>-1.285569</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>0.318047</td>\n",
       "      <td>-1.288352</td>\n",
       "      <td>0.376334</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2665639</td>\n",
       "      <td>66327</td>\n",
       "      <td>19261</td>\n",
       "      <td>19397</td>\n",
       "      <td>16204</td>\n",
       "      <td>2497109</td>\n",
       "      <td>16447</td>\n",
       "      <td>2620458</td>\n",
       "      <td>2349810</td>\n",
       "      <td>3.510283</td>\n",
       "      <td>0.486228</td>\n",
       "      <td>0.581269</td>\n",
       "      <td>-0.666240</td>\n",
       "      <td>0.446405</td>\n",
       "      <td>0.533477</td>\n",
       "      <td>0.492186</td>\n",
       "      <td>0.459563</td>\n",
       "      <td>0.388496</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   ts_hour_user_id_brand_cross  ts_weekday_user_id_brand_cross  \\\n",
       "0                       817883                          908980   \n",
       "1                      4152058                         2732403   \n",
       "2                      3204608                          274365   \n",
       "3                            0                               0   \n",
       "4                      2665639                           66327   \n",
       "\n",
       "   cat_0_user_id_brand_cross  cat_1_user_id_brand_cross  \\\n",
       "0                          0                          0   \n",
       "1                    1210052                     712779   \n",
       "2                      30730                      31144   \n",
       "3                          0                          0   \n",
       "4                      19261                      19397   \n",
       "\n",
       "   cat_2_user_id_brand_cross  product_id_user_id_cross  brand_user_id_cross  \\\n",
       "0                          0                   1085846               855303   \n",
       "1                          0                   1360101               954962   \n",
       "2                      25505                     29457                32720   \n",
       "3                          0                   3464677                    0   \n",
       "4                      16204                   2497109                16447   \n",
       "\n",
       "   ts_hour_user_id_cross  ts_minute_user_id_cross     price  ts_weekday  \\\n",
       "0                 144463                  5417827 -0.725652   -0.502085   \n",
       "1                 731363                  2230166 -0.836849   -0.007929   \n",
       "2                3039842                  3062261 -0.184922   -0.007929   \n",
       "3                2467278                  2493129 -0.841169   -0.007929   \n",
       "4                2620458                  2349810  3.510283    0.486228   \n",
       "\n",
       "     ts_day  ts_month  TE_brand_target_TE  TE_user_id_target_TE  \\\n",
       "0  0.235441  1.314784            0.199788              0.325234   \n",
       "1 -0.802043 -0.864342            0.355313              0.266837   \n",
       "2  1.618752 -0.864342            0.466206              0.237990   \n",
       "3  0.465993 -0.666240           -1.285569              0.390281   \n",
       "4  0.581269 -0.666240            0.446405              0.533477   \n",
       "\n",
       "   TE_product_id_target_TE  TE_cat_2_target_TE  \\\n",
       "0                 0.227236           -1.284867   \n",
       "1                 0.255486           -1.285741   \n",
       "2                 0.414308            0.459563   \n",
       "3                 0.318047           -1.288352   \n",
       "4                 0.492186            0.459563   \n",
       "\n",
       "   TE_ts_weekday_ts_day_target_TE  target  \n",
       "0                        0.387063     0.0  \n",
       "1                        0.420646     0.0  \n",
       "2                        0.239809     0.0  \n",
       "3                        0.376334     0.0  \n",
       "4                        0.388496     0.0  "
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nvtdata = pd.read_parquet(os.path.join(output_train_dir, 'part_0.parquet'))\n",
    "nvtdata.head()"
   ]
  },
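  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Beyond looking at `head()`, the column dtypes can be compared against `dict_dtypes` programmatically. The helper `find_dtype_mismatches` below is a hypothetical sketch, not part of NVTabular: it reports columns whose dtype differs from the requested one, as well as missing columns. It runs on a toy DataFrame so that it is self-contained; in this notebook you could call `find_dtype_mismatches(nvtdata, dict_dtypes)`, which should return an empty dictionary if the files were written with the requested dtypes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def find_dtype_mismatches(df, expected_dtypes):\n",
    "    # Map each problematic column to (expected dtype, actual dtype or None if missing)\n",
    "    mismatches = {}\n",
    "    for col, expected in expected_dtypes.items():\n",
    "        expected_dtype = np.dtype(expected)\n",
    "        if col not in df.columns:\n",
    "            mismatches[col] = (expected_dtype.name, None)\n",
    "        elif df[col].dtype != expected_dtype:\n",
    "            mismatches[col] = (expected_dtype.name, df[col].dtype.name)\n",
    "    return mismatches\n",
    "\n",
    "# Toy frame for illustration: 'price' has the wrong dtype and 'target' is missing\n",
    "toy = pd.DataFrame({\n",
    "    'product_id_user_id_cross': np.array([3, 0], dtype=np.int64),\n",
    "    'price': np.array([0.1, -0.2], dtype=np.float64),\n",
    "})\n",
    "result = find_dtype_mismatches(toy, {'product_id_user_id_cross': np.int64,\n",
    "                                     'price': np.float32,\n",
    "                                     'target': np.float32})\n",
    "print(result)"
   ]
  },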
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_file_list.txt\t_metadata  _metadata.json  part_0.parquet  schema.pbtxt\n"
     ]
    }
   ],
   "source": [
    "!ls $output_valid_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ts_hour_user_id_brand_cross</th>\n",
       "      <th>ts_weekday_user_id_brand_cross</th>\n",
       "      <th>cat_0_user_id_brand_cross</th>\n",
       "      <th>cat_1_user_id_brand_cross</th>\n",
       "      <th>cat_2_user_id_brand_cross</th>\n",
       "      <th>product_id_user_id_cross</th>\n",
       "      <th>brand_user_id_cross</th>\n",
       "      <th>ts_hour_user_id_cross</th>\n",
       "      <th>ts_minute_user_id_cross</th>\n",
       "      <th>price</th>\n",
       "      <th>ts_weekday</th>\n",
       "      <th>ts_day</th>\n",
       "      <th>ts_month</th>\n",
       "      <th>TE_brand_target_TE</th>\n",
       "      <th>TE_user_id_target_TE</th>\n",
       "      <th>TE_product_id_target_TE</th>\n",
       "      <th>TE_cat_2_target_TE</th>\n",
       "      <th>TE_ts_weekday_ts_day_target_TE</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1.107537</td>\n",
       "      <td>0.980384</td>\n",
       "      <td>-0.225663</td>\n",
       "      <td>-0.468138</td>\n",
       "      <td>0.372078</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>0.427259</td>\n",
       "      <td>0.390176</td>\n",
       "      <td>0.353950</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.840005</td>\n",
       "      <td>0.980384</td>\n",
       "      <td>-0.225663</td>\n",
       "      <td>-0.468138</td>\n",
       "      <td>0.364968</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>0.320797</td>\n",
       "      <td>-1.284867</td>\n",
       "      <td>0.352739</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.327548</td>\n",
       "      <td>-1.490397</td>\n",
       "      <td>-0.802043</td>\n",
       "      <td>-0.468138</td>\n",
       "      <td>0.466705</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>0.498779</td>\n",
       "      <td>0.459734</td>\n",
       "      <td>0.380590</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2.291189</td>\n",
       "      <td>-1.490397</td>\n",
       "      <td>-1.608975</td>\n",
       "      <td>-0.468138</td>\n",
       "      <td>0.446405</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>0.303992</td>\n",
       "      <td>0.459563</td>\n",
       "      <td>0.408594</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2677742</td>\n",
       "      <td>1812225</td>\n",
       "      <td>1033276</td>\n",
       "      <td>3759307</td>\n",
       "      <td>2828489</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.738273</td>\n",
       "      <td>-0.007929</td>\n",
       "      <td>1.157649</td>\n",
       "      <td>-0.468138</td>\n",
       "      <td>0.277334</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>0.326134</td>\n",
       "      <td>0.257637</td>\n",
       "      <td>0.409076</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   ts_hour_user_id_brand_cross  ts_weekday_user_id_brand_cross  \\\n",
       "0                            0                               0   \n",
       "1                            0                               0   \n",
       "2                            0                               0   \n",
       "3                            0                               0   \n",
       "4                            0                               0   \n",
       "\n",
       "   cat_0_user_id_brand_cross  cat_1_user_id_brand_cross  \\\n",
       "0                          0                          0   \n",
       "1                          0                          0   \n",
       "2                          0                          0   \n",
       "3                          0                          0   \n",
       "4                    2677742                    1812225   \n",
       "\n",
       "   cat_2_user_id_brand_cross  product_id_user_id_cross  brand_user_id_cross  \\\n",
       "0                          0                         0                    0   \n",
       "1                          0                         0                    0   \n",
       "2                          0                         0                    0   \n",
       "3                          0                         0                    0   \n",
       "4                    1033276                   3759307              2828489   \n",
       "\n",
       "   ts_hour_user_id_cross  ts_minute_user_id_cross     price  ts_weekday  \\\n",
       "0                      0                        0  1.107537    0.980384   \n",
       "1                      0                        0 -0.840005    0.980384   \n",
       "2                      0                        0 -0.327548   -1.490397   \n",
       "3                      0                        0  2.291189   -1.490397   \n",
       "4                      0                        0 -0.738273   -0.007929   \n",
       "\n",
       "     ts_day  ts_month  TE_brand_target_TE  TE_user_id_target_TE  \\\n",
       "0 -0.225663 -0.468138            0.372078              0.390281   \n",
       "1 -0.225663 -0.468138            0.364968              0.390281   \n",
       "2 -0.802043 -0.468138            0.466705              0.390281   \n",
       "3 -1.608975 -0.468138            0.446405              0.390281   \n",
       "4  1.157649 -0.468138            0.277334              0.390281   \n",
       "\n",
       "   TE_product_id_target_TE  TE_cat_2_target_TE  \\\n",
       "0                 0.427259            0.390176   \n",
       "1                 0.320797           -1.284867   \n",
       "2                 0.498779            0.459734   \n",
       "3                 0.303992            0.459563   \n",
       "4                 0.326134            0.257637   \n",
       "\n",
       "   TE_ts_weekday_ts_day_target_TE  target  \n",
       "0                        0.353950     0.0  \n",
       "1                        0.352739     0.0  \n",
       "2                        0.380590     1.0  \n",
       "3                        0.408594     0.0  \n",
       "4                        0.409076     0.0  "
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nvtdata_valid = pd.read_parquet(os.path.join(output_valid_dir, 'part_0.parquet'))\n",
    "nvtdata_valid.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2359020"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(nvtdata_valid['ts_hour_user_id_brand_cross']==0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2461719"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(nvtdata_valid)"
   ]
  },
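  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two counts above mean that 2,359,020 of the 2,461,719 validation rows, roughly 95.8%, have the value 0 in `ts_hour_user_id_brand_cross`. Assuming, as the rows shown earlier suggest, that 0 is the index assigned to combinations not seen while fitting on the training data, this suggests that most of these user-level crosses in the validation month are new. The hypothetical helper `zero_fraction` below extends the check to several columns at once; it is shown on a toy DataFrame, and in this notebook you could call `zero_fraction(nvtdata_valid, CAT_FEATS)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "def zero_fraction(df, columns):\n",
    "    # Share of rows whose encoded value is 0, per column\n",
    "    return {col: float((df[col] == 0).mean()) for col in columns}\n",
    "\n",
    "# Toy data for illustration only\n",
    "toy = pd.DataFrame({'a': [0, 3, 0, 5], 'b': [1, 2, 3, 4]})\n",
    "toy_fractions = zero_fraction(toy, ['a', 'b'])\n",
    "print(toy_fractions)\n",
    "\n",
    "# Ratio of the two counts reported above for ts_hour_user_id_brand_cross\n",
    "share = 2359020 / 2461719\n",
    "print(round(share, 4))"
   ]
  },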
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Getting the embedding size\n",
    "\n",
    "Next, we get the embedding sizes for the categorical variables. `get_embedding_sizes()` returns a `(cardinality, embedding dimension)` tuple for each column; the cardinalities define the sizes of the embedding tables used by HugeCTR."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'ts_hour_user_id_brand_cross': (4427037, 512),\n",
       " 'ts_weekday_user_id_brand_cross': (3961156, 512),\n",
       " 'cat_0_user_id_brand_cross': (2877223, 512),\n",
       " 'cat_1_user_id_brand_cross': (2890639, 512),\n",
       " 'cat_2_user_id_brand_cross': (2159304, 512),\n",
       " 'product_id_user_id_cross': (4398425, 512),\n",
       " 'brand_user_id_cross': (3009092, 512),\n",
       " 'ts_hour_user_id_cross': (3999369, 512),\n",
       " 'ts_minute_user_id_cross': (5931061, 512)}"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeddings = ops.get_embedding_sizes(proc)\n",
    "embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4427037, 3961156, 2877223, 2890639, 2159304, 4398425, 3009092, 3999369, 5931061]\n"
     ]
    }
   ],
   "source": [
    "print([embeddings[x][0] for x in cat_feats.output_columns.names])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['ts_hour_user_id_brand_cross',\n",
       " 'ts_weekday_user_id_brand_cross',\n",
       " 'cat_0_user_id_brand_cross',\n",
       " 'cat_1_user_id_brand_cross',\n",
       " 'cat_2_user_id_brand_cross',\n",
       " 'product_id_user_id_cross',\n",
       " 'brand_user_id_cross',\n",
       " 'ts_hour_user_id_cross',\n",
       " 'ts_minute_user_id_cross']"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cat_feats.output_columns.names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[4427037, 3961156, 2877223, 2890639, 2159304, 4398425, 3009092, 3999369, 5931061]'"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding_size_str = \"{}\".format([embeddings[x][0] for x in cat_feats.output_columns.names])\n",
    "embedding_size_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_con_feates = len(CON_FEATS)\n",
    "num_con_feates"
   ]
  },
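  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cardinalities above determine how much memory the embedding tables need. The sketch below estimates it by summing the cardinalities and multiplying by the embedding vector size and the number of bytes per value. The vector size of 128 matches `embedding_vec_size` in the HugeCTR training script later in this notebook; the 4 bytes per value assumes float32 weights, and the estimate ignores any additional optimizer state or hash-table overhead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cardinalities of the nine categorical features, copied from the\n",
    "# get_embedding_sizes() output above\n",
    "slot_size_array = [4427037, 3961156, 2877223, 2890639, 2159304,\n",
    "                   4398425, 3009092, 3999369, 5931061]\n",
    "embedding_vec_size = 128  # same value as in the HugeCTR training script\n",
    "bytes_per_value = 4       # assumption: float32 embedding weights\n",
    "\n",
    "total_rows = sum(slot_size_array)\n",
    "total_bytes = total_rows * embedding_vec_size * bytes_per_value\n",
    "print(total_rows)\n",
    "print(round(total_bytes / 1024**3, 2), 'GB')"
   ]
  },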
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we shut down the Dask client and cluster from earlier to free up memory for HugeCTR."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "client.shutdown()\n",
    "cluster.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing the training Python script for HugeCTR\n",
    "\n",
    "The HugeCTR model can be defined with the HugeCTR Python API. The following script defines a DLRM model and specifies the training resources.\n",
    "\n",
    "The following parameters must be edited to match this dataset:\n",
    "\n",
    "- `slot_size_array`: cardinalities of the categorical variables\n",
    "- `dense_dim`: number of dense features\n",
    "- `slot_num`: number of categorical variables\n",
    "\n",
    "The model graph can be saved to a JSON file by calling `model.graph_to_json`; this file is used later for inference.\n",
    "\n",
    "In the following code, we train the network using 8 GPUs and a workspace of 4000 MB per GPU. Note that the total embedding size is `33653306*128*4/(1024**3)` ≈ 16.05 GB, where 33653306 is the sum of the cardinalities in `slot_size_array`, 128 is the embedding vector size, and 4 is the number of bytes per float32 value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting hugectr_dlrm_ecommerce.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile hugectr_dlrm_ecommerce.py\n",
    "import hugectr\n",
    "from mpi4py import MPI\n",
    "solver = hugectr.CreateSolver(max_eval_batches = 2720,\n",
    "                              batchsize_eval = 16384,\n",
    "                              batchsize = 16384,\n",
    "                              lr = 0.1,\n",
    "                              warmup_steps = 8000,\n",
    "                              decay_start = 48000,\n",
    "                              decay_steps = 24000,\n",
    "                              vvgpu = [[0,1,2,3,4,5,6,7]],\n",
    "                              repeat_dataset = True,\n",
    "                              i64_input_key = True)\n",
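    "# Parquet data reader; slot_size_array holds the cardinality of each of the 9 categorical columns\n",
    "reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Parquet,\n",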
    "                                  source = [\"./nvtabular_temp/output/train/_file_list.txt\"],\n",
    "                                  eval_source = \"./nvtabular_temp/output/valid/_file_list.txt\",\n",
    "                                  check_type = hugectr.Check_t.Non,\n",
    "                                  slot_size_array = [4427037, 3961156, 2877223, 2890639, 2159304, 4398425, 3009092, 3999369, 5931061])\n",
    "optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.SGD,\n",
    "                                    update_type = hugectr.Update_t.Local,\n",
    "                                    atomic_update = True)\n",
    "model = hugectr.Model(solver, reader, optimizer)\n",
    "# Inputs: 1 label, 9 dense features, and a sparse input 'data1' with 9 slots of 1 key each\n",
    "model.add(hugectr.Input(label_dim = 1, label_name = \"label\",\n",
    "                        dense_dim = 9, dense_name = \"dense\",\n",
    "                        data_reader_sparse_param_array = \n",
    "                        [hugectr.DataReaderSparseParam(\"data1\", 1, True, 9)]))\n",
    "# Embedding table distributed across all GPUs: 128-dim vectors, 4000 MB workspace per GPU\n",
    "model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,\n",
    "                            workspace_size_per_gpu_in_mb = 4000,\n",
    "                            embedding_vec_size = 128,\n",
    "                            combiner = 'sum',\n",
    "                            sparse_embedding_name = \"sparse_embedding1\",\n",
    "                            bottom_name = \"data1\",\n",
    "                            optimizer = optimizer))\n",
    "# Bottom MLP (512-256-128) over the 9 dense features\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"dense\"],\n",
    "                            top_names = [\"fc1\"],\n",
    "                            num_output=512))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,\n",
    "                            bottom_names = [\"fc1\"],\n",
    "                            top_names = [\"relu1\"]))                           \n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"relu1\"],\n",
    "                            top_names = [\"fc2\"],\n",
    "                            num_output=256))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,\n",
    "                            bottom_names = [\"fc2\"],\n",
    "                            top_names = [\"relu2\"]))                            \n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"relu2\"],\n",
    "                            top_names = [\"fc3\"],\n",
    "                            num_output=128))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,\n",
    "                            bottom_names = [\"fc3\"],\n",
    "                            top_names = [\"relu3\"]))                              \n",
    "# DLRM dot interaction between the bottom-MLP output and the 9 embedding vectors\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Interaction,\n",
    "                            bottom_names = [\"relu3\",\"sparse_embedding1\"],\n",
    "                            top_names = [\"interaction1\"]))\n",
    "# Top MLP (1024-1024-512-256-1); its single output feeds the binary cross-entropy loss\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"interaction1\"],\n",
    "                            top_names = [\"fc4\"],\n",
    "                            num_output=1024))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,\n",
    "                            bottom_names = [\"fc4\"],\n",
    "                            top_names = [\"relu4\"]))                              \n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"relu4\"],\n",
    "                            top_names = [\"fc5\"],\n",
    "                            num_output=1024))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,\n",
    "                            bottom_names = [\"fc5\"],\n",
    "                            top_names = [\"relu5\"]))                              \n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"relu5\"],\n",
    "                            top_names = [\"fc6\"],\n",
    "                            num_output=512))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,\n",
    "                            bottom_names = [\"fc6\"],\n",
    "                            top_names = [\"relu6\"]))                               \n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"relu6\"],\n",
    "                            top_names = [\"fc7\"],\n",
    "                            num_output=256))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,\n",
    "                            bottom_names = [\"fc7\"],\n",
    "                            top_names = [\"relu7\"]))                                                                              \n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"relu7\"],\n",
    "                            top_names = [\"fc8\"],\n",
    "                            num_output=1))                                                                                           \n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,\n",
    "                            bottom_names = [\"fc8\", \"label\"],\n",
    "                            top_names = [\"loss\"]))\n",
    "model.compile()\n",
    "model.summary()\n",
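    "# Save the model graph in JSON format for inference\n",
    "model.graph_to_json(graph_config_file = \"dlrm_ecommerce.json\")\n",
    "# Train for 12000 iterations, evaluating every 3000 and saving a snapshot at iteration 10000\n",
    "model.fit(max_iter = 12000, display = 1000, eval_interval = 3000, snapshot = 10000, snapshot_prefix = \"./\")"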
   ]
  },
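  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `(None, 174)` output shape of the `Interaction` layer in the model summary can be explained with a small NumPy sketch of the DLRM dot interaction: the bottom-MLP output (128 values) is concatenated with the pairwise dot products of the 10 vectors (the bottom-MLP output plus the 9 embeddings), giving 128 + 45 = 173 values. The extra zero column is our assumption, added only to match the 174 reported by HugeCTR. This is an illustration, not HugeCTR's implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def dlrm_dot_interaction(bottom_mlp_out, embeddings):\n",
    "    # bottom_mlp_out: (batch, d); embeddings: (batch, num_slots, d)\n",
    "    batch = bottom_mlp_out.shape[0]\n",
    "    feats = np.concatenate([bottom_mlp_out[:, None, :], embeddings], axis=1)  # (batch, n, d)\n",
    "    n = feats.shape[1]\n",
    "    dots = np.einsum('bid,bjd->bij', feats, feats)  # all pairwise dot products\n",
    "    rows, cols = np.tril_indices(n, k=-1)  # strictly lower triangle: each pair once\n",
    "    pairwise = dots[:, rows, cols]  # (batch, n * (n - 1) / 2)\n",
    "    padding = np.zeros((batch, 1), dtype=feats.dtype)  # assumed padding column\n",
    "    return np.concatenate([bottom_mlp_out, pairwise, padding], axis=1)\n",
    "\n",
    "out = dlrm_dot_interaction(np.ones((2, 128), np.float32),\n",
    "                           np.ones((2, 9, 128), np.float32))\n",
    "print(out.shape)  # (2, 174)\n",
    "```"
   ]
  },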
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HugeCTR training\n",
    "\n",
    "Now we are ready to train a DLRM model with HugeCTR.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "HugeCTR Version: 3.2\n",
      "====================================================Model Init=====================================================\n",
      "[HUGECTR][23:17:45][INFO][RANK0]: Global seed is 1985961998\n",
      "[HUGECTR][23:17:45][INFO][RANK0]: Device to NUMA mapping:\n",
      "  GPU 0 ->  node 0\n",
      "  GPU 1 ->  node 0\n",
      "  GPU 2 ->  node 0\n",
      "  GPU 3 ->  node 0\n",
      "  GPU 4 ->  node 1\n",
      "  GPU 5 ->  node 1\n",
      "  GPU 6 ->  node 1\n",
      "  GPU 7 ->  node 1\n",
      "\n",
      "[HUGECTR][23:17:54][WARNING][RANK0]: Peer-to-peer access cannot be fully enabled.\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: Start all2all warmup\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: End all2all warmup\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: Using All-reduce algorithm: NCCL\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: Device 0: Tesla P100-SXM2-16GB\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: Device 1: Tesla P100-SXM2-16GB\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: Device 2: Tesla P100-SXM2-16GB\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: Device 3: Tesla P100-SXM2-16GB\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: Device 4: Tesla P100-SXM2-16GB\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: Device 5: Tesla P100-SXM2-16GB\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: Device 6: Tesla P100-SXM2-16GB\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: Device 7: Tesla P100-SXM2-16GB\n",
      "[HUGECTR][23:17:54][INFO][RANK0]: num of DataReader workers: 8\n",
      "[HUGECTR][23:17:55][INFO][RANK0]: Vocabulary size: 33653306\n",
      "[HUGECTR][23:17:55][INFO][RANK0]: max_vocabulary_size_per_gpu_=8192000\n",
      "[HUGECTR][23:17:58][INFO][RANK0]: Graph analysis to resolve tensor dependency\n",
      "===================================================Model Compile===================================================\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu0 start to init embedding\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu1 start to init embedding\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu6 start to init embedding\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu2 start to init embedding\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu7 start to init embedding\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu5 start to init embedding\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu3 start to init embedding\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu4 start to init embedding\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu0 init embedding done\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu1 init embedding done\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu2 init embedding done\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu7 init embedding done\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu6 init embedding done\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu5 init embedding done\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu4 init embedding done\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: gpu3 init embedding done\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: Starting AUC NCCL warm-up\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: Warm-up done\n",
      "===================================================Model Summary===================================================\n",
      "label                                   Dense                         Sparse                        \n",
      "label                                   dense                          data1                         \n",
      "(None, 1)                               (None, 9)                               \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "Layer Type                              Input Name                    Output Name                   Output Shape                  \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "DistributedSlotSparseEmbeddingHash      data1                         sparse_embedding1             (None, 9, 128)                \n",
      "InnerProduct                            dense                         fc1                           (None, 512)                   \n",
      "ReLU                                    fc1                           relu1                         (None, 512)                   \n",
      "InnerProduct                            relu1                         fc2                           (None, 256)                   \n",
      "ReLU                                    fc2                           relu2                         (None, 256)                   \n",
      "InnerProduct                            relu2                         fc3                           (None, 128)                   \n",
      "ReLU                                    fc3                           relu3                         (None, 128)                   \n",
      "Interaction                             relu3,sparse_embedding1       interaction1                  (None, 174)                   \n",
      "InnerProduct                            interaction1                  fc4                           (None, 1024)                  \n",
      "ReLU                                    fc4                           relu4                         (None, 1024)                  \n",
      "InnerProduct                            relu4                         fc5                           (None, 1024)                  \n",
      "ReLU                                    fc5                           relu5                         (None, 1024)                  \n",
      "InnerProduct                            relu5                         fc6                           (None, 512)                   \n",
      "ReLU                                    fc6                           relu6                         (None, 512)                   \n",
      "InnerProduct                            relu6                         fc7                           (None, 256)                   \n",
      "ReLU                                    fc7                           relu7                         (None, 256)                   \n",
      "InnerProduct                            relu7                         fc8                           (None, 1)                     \n",
      "BinaryCrossEntropyLoss                  fc8,label                     loss                                                        \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: Save the model graph to dlrm_ecommerce.json successfully\n",
      "=====================================================Model Fit=====================================================\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: Use non-epoch mode with number of iterations: 12000\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: Training batchsize: 16384, evaluation batchsize: 16384\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: Evaluation interval: 3000, snapshot interval: 10000\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: Sparse embedding trainable: 1, dense network trainable: 1\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: Use mixed precision: 0, scaler: 1, use cuda graph: -875196854\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: lr: 0.100000, warmup_steps: 8000, decay_start: 48000, decay_steps: 24000, decay_power: 2.000000, end_lr: 0.000000\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: Training source file: ./nvtabular_temp/output/train/_file_list.txt\n",
      "[HUGECTR][23:18:12][INFO][RANK0]: Evaluation source file: ./nvtabular_temp/output/valid/_file_list.txt\n",
      "[HUGECTR][23:18:20][INFO][RANK0]: Iter: 1000 Time(1000 iters): 8.477706s Loss: 0.654302 lr:0.012512\n",
      "[HUGECTR][23:18:29][INFO][RANK0]: Iter: 2000 Time(1000 iters): 8.461642s Loss: 0.537260 lr:0.025013\n",
      "[HUGECTR][23:18:37][INFO][RANK0]: Iter: 3000 Time(1000 iters): 8.473848s Loss: 0.523659 lr:0.037512\n",
      "[HUGECTR][23:18:47][INFO][RANK0]: Evaluation, AUC: 0.652278\n",
      "[HUGECTR][23:18:47][INFO][RANK0]: Eval Time for 2720 iters: 9.794543s\n",
      "[HUGECTR][23:18:56][INFO][RANK0]: Iter: 4000 Time(1000 iters): 18.361339s Loss: 0.521578 lr:0.050012\n",
      "[HUGECTR][23:19:04][INFO][RANK0]: Iter: 5000 Time(1000 iters): 8.492043s Loss: 0.515692 lr:0.062513\n",
      "[HUGECTR][23:19:13][INFO][RANK0]: Iter: 6000 Time(1000 iters): 8.491605s Loss: 0.518826 lr:0.075013\n",
      "[HUGECTR][23:19:22][INFO][RANK0]: Evaluation, AUC: 0.650539\n",
      "[HUGECTR][23:19:22][INFO][RANK0]: Eval Time for 2720 iters: 9.814989s\n",
      "[HUGECTR][23:19:31][INFO][RANK0]: Iter: 7000 Time(1000 iters): 18.332924s Loss: 0.511855 lr:0.087513\n",
      "[HUGECTR][23:19:39][INFO][RANK0]: Iter: 8000 Time(1000 iters): 8.488666s Loss: 0.515189 lr:0.100000\n",
      "[HUGECTR][23:19:48][INFO][RANK0]: Iter: 9000 Time(1000 iters): 8.455840s Loss: 0.513654 lr:0.100000\n",
      "[HUGECTR][23:19:58][INFO][RANK0]: Evaluation, AUC: 0.645823\n",
      "[HUGECTR][23:19:58][INFO][RANK0]: Eval Time for 2720 iters: 9.750920s\n",
      "[HUGECTR][23:20:06][INFO][RANK0]: Iter: 10000 Time(1000 iters): 18.285750s Loss: 0.518827 lr:0.100000\n",
      "[HUGECTR][23:20:15][INFO][RANK0]: Rank0: Write hash table to file\n",
      "[HUGECTR][23:21:26][INFO][RANK0]: Dumping sparse weights to files, successful\n",
      "[HUGECTR][23:21:26][INFO][RANK0]: Dumping sparse optimzer states to files, successful\n",
      "[HUGECTR][23:21:26][INFO][RANK0]: Dumping dense weights to file, successful\n",
      "[HUGECTR][23:21:26][INFO][RANK0]: Dumping dense optimizer states to file, successful\n",
      "[HUGECTR][23:21:26][INFO][RANK0]: Dumping untrainable weights to file, successful\n",
      "[HUGECTR][23:21:35][INFO][RANK0]: Iter: 11000 Time(1000 iters): 88.781702s Loss: 0.511783 lr:0.100000\n",
      "[HUGECTR][23:21:43][INFO][RANK0]: Finish 12000 iterations with batchsize: 16384 in 211.67s.\n"
     ]
    }
   ],
   "source": [
    "!python3 hugectr_dlrm_ecommerce.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HugeCTR inference\n",
    "\n",
    "In this section, we read the test dataset and compute the AUC value.\n",
    "\n",
    "For inference, we use the model graph saved in JSON format (`dlrm_ecommerce.json`) together with the dense and sparse weights from the snapshot at iteration 10000."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare the inference session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from hugectr.inference import InferenceParams, CreateInferenceSession\n",
    "from mpi4py import MPI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[HUGECTR][23:21:46][INFO][RANK0]: default_emb_vec_value is not specified using default: 0.000000\n",
      "[HUGECTR][23:21:46][INFO][RANK0]: Created parallel (16 partitions) blank database backend in local memory!\n",
      "[HUGECTR][23:22:51][INFO][RANK0]: ParallelLocalMemory backend. Table: dlrm#0. Inserted 33653303 / 33653303 pairs.\n",
      "[HUGECTR][23:22:51][INFO][RANK0]: Cached 0.000000 * 33653303 embeddings in CPU memory database!\n",
      "[HUGECTR][23:22:52][INFO][RANK0]: Create embedding cache in device 0.\n",
      "[HUGECTR][23:22:53][INFO][RANK0]: create_refreshspace2\n",
      "[HUGECTR][23:22:53][INFO][RANK0]: Global seed is 813179416\n",
      "[HUGECTR][23:22:53][INFO][RANK0]: Device to NUMA mapping:\n",
      "  GPU 0 ->  node 0\n",
      "\n",
      "[HUGECTR][23:22:54][WARNING][RANK0]: Peer-to-peer access cannot be fully enabled.\n",
      "[HUGECTR][23:22:54][INFO][RANK0]: Start all2all warmup\n",
      "[HUGECTR][23:22:54][INFO][RANK0]: End all2all warmup\n",
      "[HUGECTR][23:22:54][INFO][RANK0]: Use mixed precision: 0\n",
      "[HUGECTR][23:22:54][INFO][RANK0]: start create embedding for inference\n",
      "[HUGECTR][23:22:54][INFO][RANK0]: sparse_input name data1\n",
      "[HUGECTR][23:22:54][INFO][RANK0]: create embedding for inference success\n",
      "[HUGECTR][23:22:54][INFO][RANK0]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer\n"
     ]
    }
   ],
   "source": [
    "# create inference session\n",
    "inference_params = InferenceParams(model_name = \"dlrm\",\n",
    "                              max_batchsize = 4096,\n",
    "                              hit_rate_threshold = 0.6,\n",
    "                              dense_model_file = \"./_dense_10000.model\",\n",
    "                              sparse_model_files = [\"./0_sparse_10000.model\"],\n",
    "                              device_id = 0,\n",
    "                              use_gpu_embedding_cache = True,\n",
    "                              cache_size_percentage = 0.2,\n",
    "                              i64_input_key = True)\n",
    "inference_session = CreateInferenceSession(\"dlrm_ecommerce.json\", inference_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reading and preparing the data\n",
    "\n",
    "First, we read the test data processed by NVTabular."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ts_hour_user_id_brand_cross</th>\n",
       "      <th>ts_weekday_user_id_brand_cross</th>\n",
       "      <th>cat_0_user_id_brand_cross</th>\n",
       "      <th>cat_1_user_id_brand_cross</th>\n",
       "      <th>cat_2_user_id_brand_cross</th>\n",
       "      <th>product_id_user_id_cross</th>\n",
       "      <th>brand_user_id_cross</th>\n",
       "      <th>ts_hour_user_id_cross</th>\n",
       "      <th>ts_minute_user_id_cross</th>\n",
       "      <th>price</th>\n",
       "      <th>ts_weekday</th>\n",
       "      <th>ts_day</th>\n",
       "      <th>ts_month</th>\n",
       "      <th>TE_brand_target_TE</th>\n",
       "      <th>TE_user_id_target_TE</th>\n",
       "      <th>TE_product_id_target_TE</th>\n",
       "      <th>TE_cat_2_target_TE</th>\n",
       "      <th>TE_ts_weekday_ts_day_target_TE</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.823432</td>\n",
       "      <td>-1.490397</td>\n",
       "      <td>1.272925</td>\n",
       "      <td>-0.270035</td>\n",
       "      <td>-1.287369</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>-1.285848</td>\n",
       "      <td>0.445558</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.627107</td>\n",
       "      <td>-0.996241</td>\n",
       "      <td>0.581269</td>\n",
       "      <td>-0.270035</td>\n",
       "      <td>0.306095</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>0.339375</td>\n",
       "      <td>0.352659</td>\n",
       "      <td>0.396743</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.299060</td>\n",
       "      <td>-0.996241</td>\n",
       "      <td>1.388201</td>\n",
       "      <td>-0.270035</td>\n",
       "      <td>0.364552</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>0.443069</td>\n",
       "      <td>0.459563</td>\n",
       "      <td>0.433425</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.037364</td>\n",
       "      <td>1.474540</td>\n",
       "      <td>-0.456215</td>\n",
       "      <td>-0.270035</td>\n",
       "      <td>0.466595</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>0.431157</td>\n",
       "      <td>0.460040</td>\n",
       "      <td>0.407608</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.704362</td>\n",
       "      <td>-0.502085</td>\n",
       "      <td>-0.917319</td>\n",
       "      <td>-0.270035</td>\n",
       "      <td>0.218853</td>\n",
       "      <td>0.390281</td>\n",
       "      <td>0.274339</td>\n",
       "      <td>-1.285741</td>\n",
       "      <td>0.428683</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   ts_hour_user_id_brand_cross  ts_weekday_user_id_brand_cross  \\\n",
       "0                            0                               0   \n",
       "1                            0                               0   \n",
       "2                            0                               0   \n",
       "3                            0                               0   \n",
       "4                            0                               0   \n",
       "\n",
       "   cat_0_user_id_brand_cross  cat_1_user_id_brand_cross  \\\n",
       "0                          0                          0   \n",
       "1                          0                          0   \n",
       "2                          0                          0   \n",
       "3                          0                          0   \n",
       "4                          0                          0   \n",
       "\n",
       "   cat_2_user_id_brand_cross  product_id_user_id_cross  brand_user_id_cross  \\\n",
       "0                          0                         0                    0   \n",
       "1                          0                         0                    0   \n",
       "2                          0                         0                    0   \n",
       "3                          0                         0                    0   \n",
       "4                          0                         0                    0   \n",
       "\n",
       "   ts_hour_user_id_cross  ts_minute_user_id_cross     price  ts_weekday  \\\n",
       "0                      0                        0 -0.823432   -1.490397   \n",
       "1                      0                        0 -0.627107   -0.996241   \n",
       "2                      0                        0 -0.299060   -0.996241   \n",
       "3                      0                        0 -0.037364    1.474540   \n",
       "4                      0                        0 -0.704362   -0.502085   \n",
       "\n",
       "     ts_day  ts_month  TE_brand_target_TE  TE_user_id_target_TE  \\\n",
       "0  1.272925 -0.270035           -1.287369              0.390281   \n",
       "1  0.581269 -0.270035            0.306095              0.390281   \n",
       "2  1.388201 -0.270035            0.364552              0.390281   \n",
       "3 -0.456215 -0.270035            0.466595              0.390281   \n",
       "4 -0.917319 -0.270035            0.218853              0.390281   \n",
       "\n",
       "   TE_product_id_target_TE  TE_cat_2_target_TE  \\\n",
       "0                 0.390281           -1.285848   \n",
       "1                 0.339375            0.352659   \n",
       "2                 0.443069            0.459563   \n",
       "3                 0.431157            0.460040   \n",
       "4                 0.274339           -1.285741   \n",
       "\n",
       "   TE_ts_weekday_ts_day_target_TE  target  \n",
       "0                        0.445558     0.0  \n",
       "1                        0.396743     0.0  \n",
       "2                        0.433425     0.0  \n",
       "3                        0.407608     1.0  \n",
       "4                        0.428683     0.0  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "nvtdata_test = pd.read_parquet('./nvtabular_temp/output/test/part_0.parquet')\n",
    "nvtdata_test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "con_feats = ['price',\n",
    " 'ts_weekday',\n",
    " 'ts_day',\n",
    " 'ts_month',\n",
    " 'TE_brand_target_TE',\n",
    " 'TE_user_id_target_TE',\n",
    " 'TE_product_id_target_TE',\n",
    " 'TE_cat_2_target_TE',\n",
    " 'TE_ts_weekday_ts_day_target_TE']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_feats = ['ts_hour_user_id_brand_cross',\n",
    " 'ts_weekday_user_id_brand_cross',\n",
    " 'cat_0_user_id_brand_cross',\n",
    " 'cat_1_user_id_brand_cross',\n",
    " 'cat_2_user_id_brand_cross',\n",
    " 'product_id_user_id_cross',\n",
    " 'brand_user_id_cross',\n",
    " 'ts_hour_user_id_cross',\n",
    " 'ts_minute_user_id_cross']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_size = [4427037, 3961156, 2877223, 2890639, 2159304, 4398425, 3009092, 3999369, 5931061]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Converting data to CSR format\n",
    "\n",
    "HugeCTR expects data in CSR format for inference. One important thing to note is that HugeCTR requires the categorical variables to occupy different integer ranges. For example, if there are 10 users and 10 items, then the users should be encoded in the 0-9 range, while the items should be in the 10-19 range. NVTabular, however, encodes both users and items in the 0-9 range.\n",
    "\n",
    "For this reason, we shift the keys of each categorical variable produced by NVTabular by the total cardinality of the preceding variables to comply with HugeCTR."
   ]
  },
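  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the 10-users/10-items example above, the following sketch shows how cumulative offsets move each column into its own key range. It mirrors the `np.cumsum` approach used in the next cells; `cardinalities` and `raw_keys` are toy values.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "cardinalities = [10, 10]  # toy example: 10 users, 10 items\n",
    "# Offset of each column = total cardinality of all preceding columns\n",
    "shift = np.insert(np.cumsum(cardinalities), 0, 0)[:-1]\n",
    "raw_keys = np.array([[3, 3],   # user 3, item 3\n",
    "                     [9, 0]])  # user 9, item 0\n",
    "global_keys = raw_keys + shift  # broadcasting adds the offset per column\n",
    "print(shift.tolist(), global_keys.tolist())  # [0, 10] [[3, 13], [9, 10]]\n",
    "```"
   ]
  },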
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "shift = np.insert(np.cumsum(emb_size), 0, 0)[:-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_data = nvtdata_test[cat_feats].values + shift"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "dense_data = nvtdata_test[con_feats].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
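    "def infer_batch(inference_session, dense_data_batch, cat_data_batch):\n",
    "    # Flatten the (batch, 9) dense matrix row by row into one list of floats\n",
    "    dense_features = list(dense_data_batch.flatten())\n",
    "    # Flatten the shifted categorical keys the same way: one key per slot per sample\n",
    "    embedding_columns = list(cat_data_batch.flatten())\n",
    "    # CSR row offsets: every (sample, slot) row holds exactly one key, so the offsets are 0, 1, ..., nnz\n",
    "    row_ptrs = list(range(0, len(embedding_columns) + 1))\n",
    "    output = inference_session.predict(dense_features, embedding_columns, row_ptrs)\n",
    "    return output"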
   ]
  },
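  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Judging from `infer_batch`, each (sample, slot) pair forms one CSR row, `embedding_columns` holds the keys, and `row_ptrs` holds the row offsets. Because every slot here has exactly one key, the offsets are simply 0, 1, 2, and so on. The sketch below uses a hypothetical helper, `build_csr_row_ptrs`, that builds the offsets from per-row key counts in general, and shows that they reduce to the `range` used above when every slot is one-hot.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def build_csr_row_ptrs(keys_per_row):\n",
    "    # Hypothetical helper: CSR offsets are the running total of keys per row, starting at 0\n",
    "    return np.concatenate([[0], np.cumsum(keys_per_row)]).astype(np.int64)\n",
    "\n",
    "# 2 samples x 9 slots, one key per slot (the case in this notebook)\n",
    "one_hot = build_csr_row_ptrs(np.ones(2 * 9, dtype=np.int64))\n",
    "print(one_hot.tolist() == list(range(2 * 9 + 1)))  # True\n",
    "\n",
    "# Multi-hot example: 3 rows holding 2, 0 and 3 keys\n",
    "print(build_csr_row_ptrs([2, 0, 3]).tolist())  # [0, 2, 2, 5]\n",
    "```"
   ]
  },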
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to carry out inference on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 4096\n",
    "num_batches = (len(dense_data) // batch_size) + 1\n",
    "batch_idx = np.array_split(np.arange(len(dense_data)), num_batches)"
   ]
  },
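  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batching logic can be checked on small, hypothetical row counts. `np.array_split` divides the row indices into `num_batches` contiguous, nearly equal chunks, so no batch exceeds `batch_size`, and concatenating the chunks restores the original row order, which keeps the predictions aligned with `nvtdata_test['target']`. The formula `n // batch_size + 1` yields one batch more than strictly necessary when the row count is an exact multiple of `batch_size`; a ceiling division avoids this. The helper `make_batches` and its `use_ceil` flag are hypothetical and assume only NumPy.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def make_batches(n_rows, batch_size, use_ceil=False):\n",
    "    if use_ceil:\n",
    "        # Ceiling division: the minimum number of batches (at least 1)\n",
    "        num_batches = max(1, -(-n_rows // batch_size))\n",
    "    else:\n",
    "        # Formula used in this notebook: floor division plus one\n",
    "        num_batches = n_rows // batch_size + 1\n",
    "    return np.array_split(np.arange(n_rows), num_batches)\n",
    "\n",
    "\n",
    "for n_rows in (10_000, 8192):\n",
    "    batches = make_batches(n_rows, 4096)\n",
    "    sizes = [len(b) for b in batches]\n",
    "    # Contiguous chunks concatenate back to 0, 1, ..., n_rows - 1\n",
    "    assert (np.concatenate(batches) == np.arange(n_rows)).all()\n",
    "    assert max(sizes) <= 4096\n",
    "    print(n_rows, sizes, len(make_batches(n_rows, 4096, use_ceil=True)))\n",
    "# 10000 [3334, 3333, 3333] 3\n",
    "# 8192 [2731, 2731, 2730] 2\n",
    "```\n",
    "\n",
    "Either formula produces correct results here; the ceiling version only saves one `predict` call when the row count divides evenly."
   ]
  },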
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# labels collects the model's predicted scores (not class labels) for every test row,\n",
    "# in the original test-set row order\n",
    "labels = []\n",
    "for batch_id in tqdm(batch_idx):\n",
    "    dense_data_batch = dense_data[batch_id]\n",
    "    cat_data_batch = cat_data[batch_id]\n",
    "    results = infer_batch(inference_session, dense_data_batch, cat_data_batch)\n",
    "    labels.extend(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2772486"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing the test AUC value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "ground_truth = nvtdata_test['target'].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5565971533171648"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "roc_auc_score(ground_truth, labels)"
   ]
  },
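  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ROC AUC can be read as the probability that a randomly chosen positive example (a purchase) receives a higher score than a randomly chosen negative example (a cart addition without a purchase), with ties counting one half; a value of 0.5 corresponds to a random ordering. The following minimal sketch computes this pairwise definition directly on a small, made-up example so that the result can be compared with `roc_auc_score`. The function name `pairwise_auc` is hypothetical, and its nested loop is quadratic in the number of examples, so it is only suitable for toy data, not for the full test set.\n",
    "\n",
    "```python\n",
    "def pairwise_auc(y_true, scores):\n",
    "    # Split the scores by ground-truth class\n",
    "    pos = [s for y, s in zip(y_true, scores) if y == 1]\n",
    "    neg = [s for y, s in zip(y_true, scores) if y == 0]\n",
    "    wins = 0.0\n",
    "    for p in pos:\n",
    "        for q in neg:\n",
    "            if p > q:\n",
    "                wins += 1.0  # positive ranked above negative\n",
    "            elif p == q:\n",
    "                wins += 0.5  # ties count as half a win\n",
    "    return wins / (len(pos) * len(neg))\n",
    "\n",
    "\n",
    "y_toy = [0, 0, 1, 1]\n",
    "scores_toy = [0.1, 0.4, 0.35, 0.8]\n",
    "print(pairwise_auc(y_toy, scores_toy))  # 0.75\n",
    "```\n",
    "\n",
    "On the same toy inputs, `roc_auc_score(y_toy, scores_toy)` also returns 0.75."
   ]
  },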
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "In this notebook, we walked you through preprocessing the data, training a DLRM model with HugeCTR, and carrying out inference with the HugeCTR Python interface. Try this workflow on your data and let us know your feedback.\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
